[Model] Mamba2 varlen refactor (#21467)
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Co-authored-by: RishiAstra <40644327+RishiAstra@users.noreply.github.com>

This commit is contained in: parent 633f943e30, commit
2b6b1d7809
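
The change this commit makes everywhere below is a calling-convention one: the Mamba2 chunked-scan path drops the leading batch dimension and works on a variable-length ("varlen") batch, i.e. all sequences concatenated along one token axis and described by cu_seqlens (cumulative sequence boundaries) and seq_idx (per-token sequence index). A minimal sketch of that metadata, with made-up sequence lengths (none of these values come from the commit):

import torch

# Hypothetical per-sequence prompt lengths in one continuous batch.
seqlens = torch.tensor([5, 3, 8], dtype=torch.int32)

# cu_seqlens: cumulative boundaries, shape (num_seqs + 1,) -> [0, 5, 8, 16]
cu_seqlens = torch.cat(
    [torch.zeros(1, dtype=torch.int32),
     seqlens.cumsum(0).to(torch.int32)])

# seq_idx: for every token of the flattened (total_tokens, ...) layout, the
# index of the sequence it belongs to -> [0,0,0,0,0, 1,1,1, 2,...,2]
seq_idx = torch.repeat_interleave(
    torch.arange(len(seqlens), dtype=torch.int32), seqlens)

# Inputs such as X, dt, B, C are then passed as (total_tokens, ...) tensors,
# e.g. X.shape == (int(cu_seqlens[-1]), nheads, headdim), with no batch dim.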
@@ -7,7 +7,7 @@ import torch.nn.functional as F
from einops import rearrange, repeat

from vllm.model_executor.layers.mamba.ops.ssd_combined import (
mamba_chunk_scan_combined)
mamba_chunk_scan_combined_varlen)
from vllm.platforms import current_platform
from vllm.v1.attention.backends.mamba2_attn import (
_query_start_loc_to_chunk_indices_offsets)
@@ -185,9 +185,14 @@ def generate_continuous_batched_examples(example_lens_by_batch,
IND_S = [x % full_length for x in IND_E]
IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)]

# varlen has implicit batch=1
dt2 = dt2.squeeze(0)
X2 = X2.squeeze(0)
B2 = B2.squeeze(0)
C2 = C2.squeeze(0)
yield ([Y_min[s, IND_S[s]:IND_E[s]]
for s in range(num_examples)] if return_naive_ref else None,
cu_seqlens, seq_idx.unsqueeze(0), (A, dt2, X2, B2, C2))
cu_seqlens, seq_idx, (A, dt2, X2, B2, C2))

@pytest.mark.parametrize("itype",
@@ -198,7 +203,7 @@ def generate_continuous_batched_examples(example_lens_by_batch,
def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
itype):

# this tests the kernels on a single example (no batching)
# this tests the kernels on a single example (bs=1)

# TODO: the bfloat16 case requires higher thresholds. To be investigated

@@ -219,23 +224,40 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,

Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt,
B, C, chunk_size)

cu_seqlens = torch.tensor((0, seqlen), device='cuda').cumsum(dim=0)
seq_idx = torch.zeros(seqlen, dtype=torch.int32, device=cu_seqlens.device)

chunk_indices, chunk_offsets = \
_query_start_loc_to_chunk_indices_offsets(
cu_seqlens, chunk_size, cu_seqlens[-1])

# varlen has implicit batch=1
X = X.squeeze(0)
dt = dt.squeeze(0)
A = A.squeeze(0)
B = B.squeeze(0)
C = C.squeeze(0)
Y = torch.empty_like(X)
final_state = mamba_chunk_scan_combined(X,
dt,
A,
B,
C,
chunk_size,
D=None,
return_final_states=True,
out=Y)
final_state = mamba_chunk_scan_combined_varlen(X,
dt,
A,
B,
C,
chunk_size,
D=None,
cu_seqlens=cu_seqlens,
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
out=Y)

# just test the last in sequence
torch.testing.assert_close(Y[:, -1], Y_min[:, -1], atol=atol, rtol=rtol)
torch.testing.assert_close(Y[-1], Y_min[0, -1], atol=atol, rtol=rtol)

# just test the last head
# NOTE, in the kernel we always cast states to fp32
torch.testing.assert_close(final_state[:, -1],
torch.testing.assert_close(final_state[:, -1].to(torch.float32),
final_state_min[:, -1].to(torch.float32),
atol=atol,
rtol=rtol)
@@ -300,7 +322,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
cu_seqlens, chunk_size, cu_seqlens[-1])

Y = torch.empty_like(X)
new_states = mamba_chunk_scan_combined(
new_states = mamba_chunk_scan_combined_varlen(
X,
dt,
A,
@@ -312,7 +334,6 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
return_varlen_states=True,
initial_states=states,
out=Y,
)
@@ -321,7 +342,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
for i in range(num_examples):

# just test one dim and dstate
Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
Y_eg = Y[cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
Y_min_eg = Y_min[i][:, 0, 0]
torch.testing.assert_close(Y_eg, Y_min_eg, atol=atol, rtol=rtol)
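
With the batch dimension gone, per-sequence results are recovered by slicing the flattened (total_tokens, ...) output with cu_seqlens, exactly as the updated assertions above do. A small illustrative helper in the same spirit (the name split_varlen is ours, not part of the commit):

import torch

def split_varlen(y: torch.Tensor, cu_seqlens: torch.Tensor) -> list[torch.Tensor]:
    """Split a flattened varlen tensor (total_tokens, ...) into per-sequence views."""
    return [
        y[cu_seqlens[i]:cu_seqlens[i + 1]]
        for i in range(len(cu_seqlens) - 1)
    ]

# Example with dummy data: three sequences of lengths 5, 3 and 8.
cu_seqlens = torch.tensor([0, 5, 8, 16])
y = torch.randn(16, 4, 64)          # (total_tokens, nheads, headdim)
parts = split_varlen(y, cu_seqlens)
assert [p.shape[0] for p in parts] == [5, 3, 8]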

@@ -386,7 +407,7 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
_query_start_loc_to_chunk_indices_offsets(
cu_seqlens, chunk_size, cu_seqlens[-1])
Y_ref = torch.empty_like(X)
state_ref = mamba_chunk_scan_combined(
state_ref = mamba_chunk_scan_combined_varlen(
X,
dt,
A,
@@ -398,7 +419,6 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
return_varlen_states=True,
initial_states=None,
out=Y_ref,
)
@@ -414,27 +434,27 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
chunked_seq_idx = torch.repeat_interleave(
torch.arange(len(chunked_seqlens), device=device),
chunked_seqlens,
output_size=chunked_cu_seqlens[-1]).unsqueeze(0).to(torch.int32)
output_size=chunked_cu_seqlens[-1]).to(torch.int32)
chunked_input_seq_len = chunked_cu_seqlens[-1]
X_chunked = torch.zeros_like(X)[:, :chunked_input_seq_len, ...]
dt_chunked = torch.zeros_like(dt)[:, :chunked_input_seq_len, ...]
B_chunked = torch.zeros_like(B)[:, :chunked_input_seq_len, ...]
C_chunked = torch.zeros_like(C)[:, :chunked_input_seq_len, ...]
X_chunked = torch.zeros_like(X)[:chunked_input_seq_len, ...]
dt_chunked = torch.zeros_like(dt)[:chunked_input_seq_len, ...]
B_chunked = torch.zeros_like(B)[:chunked_input_seq_len, ...]
C_chunked = torch.zeros_like(C)[:chunked_input_seq_len, ...]
for i in range(num_sequences):
# fmt: off
chunk_f = lambda x, i: x[:, cu_seqlens[i]:cu_seqlens[i] + chunked_seqlens[i], ...] # noqa: E501
chunk_f = lambda x, i: x[cu_seqlens[i]:cu_seqlens[i] + chunked_seqlens[i], ...] # noqa: E501

X_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(X, i) # noqa: E501
dt_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(dt, i) # noqa: E501
B_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(B, i) # noqa: E501
C_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(C, i) # noqa: E501
X_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(X, i) # noqa: E501
dt_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(dt, i) # noqa: E501
B_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(B, i) # noqa: E501
C_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(C, i) # noqa: E501
# fmt: on

chunk_indices, chunk_offsets = \
_query_start_loc_to_chunk_indices_offsets(
chunked_cu_seqlens, chunk_size, chunked_cu_seqlens[-1])
Y_partial = torch.empty_like(X_chunked)
partial_state = mamba_chunk_scan_combined(
partial_state = mamba_chunk_scan_combined_varlen(
X_chunked,
dt_chunked,
A,
@@ -446,7 +466,6 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
seq_idx=chunked_seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
return_varlen_states=True,
initial_states=None,
out=Y_partial,
)
@@ -461,29 +480,28 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
remaining_chunked_seq_idx = torch.repeat_interleave(
torch.arange(len(remaining_chunked_seqlens), device=device),
remaining_chunked_seqlens,
output_size=remaining_chunked_cu_seqlens[-1]).unsqueeze(0).to(
torch.int32)
output_size=remaining_chunked_cu_seqlens[-1]).to(torch.int32)
remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1]
# fmt: off
remaining_X_chunked = torch.zeros_like(X)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501
remaining_dt_chunked = torch.zeros_like(dt)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501
remaining_B_chunked = torch.zeros_like(B)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501
remaining_C_chunked = torch.zeros_like(C)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501
remaining_X_chunked = torch.zeros_like(X)[:remaining_chunked_input_seq_len, ...] # noqa: E501
remaining_dt_chunked = torch.zeros_like(dt)[:remaining_chunked_input_seq_len, ...] # noqa: E501
remaining_B_chunked = torch.zeros_like(B)[:remaining_chunked_input_seq_len, ...] # noqa: E501
remaining_C_chunked = torch.zeros_like(C)[:remaining_chunked_input_seq_len, ...] # noqa: E501
for i in range(num_sequences):
remaining_chunk_f = lambda x, i: x[:, cu_seqlens[i] + chunked_seqlens[i]:cu_seqlens[i+1], ...] # noqa: E501
remaining_chunk_f = lambda x, i: x[cu_seqlens[i] + chunked_seqlens[i]:cu_seqlens[i+1], ...] # noqa: E501

remaining_X_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(X, i) # noqa: E501
remaining_dt_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(dt, i) # noqa: E501
remaining_B_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(B, i) # noqa: E501
remaining_C_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(C, i) # noqa: E501
remaining_X_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(X, i) # noqa: E501
remaining_dt_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(dt, i) # noqa: E501
remaining_B_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(B, i) # noqa: E501
remaining_C_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(C, i) # noqa: E501

# assert input chunking is correct
concat_chunk_f = lambda pt1, pt2, i: torch.cat([
pt1[:,chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1],...],
pt2[:,remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1],...],
pt1[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1],...],
pt2[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1],...],
],
dim=1)
concat_batch_f = lambda pt1, pt2: torch.cat([concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=1) # noqa: E501
dim=0)
concat_batch_f = lambda pt1, pt2: torch.cat([concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=0) # noqa: E501
# fmt: on

assert concat_batch_f(X_chunked, remaining_X_chunked).equal(X)
@@ -498,7 +516,7 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
remaining_chunked_cu_seqlens[-1])

Y_chunked = torch.empty_like(remaining_X_chunked)
state_chunked = mamba_chunk_scan_combined(
state_chunked = mamba_chunk_scan_combined_varlen(
remaining_X_chunked,
remaining_dt_chunked,
A,
@@ -510,7 +528,6 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
seq_idx=remaining_chunked_seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
return_varlen_states=True,
initial_states=partial_state,
out=Y_chunked,
)
@@ -518,17 +535,17 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):

# kernel chunked is same as kernel overall
for i in range(num_sequences):
Y_seq = Y[:, cu_seqlens[i]:cu_seqlens[i + 1], ...]
Y_ref_seq = Y_ref[:, cu_seqlens[i]:cu_seqlens[i + 1], ...]
Y_seq = Y[cu_seqlens[i]:cu_seqlens[i + 1], ...]
Y_ref_seq = Y_ref[cu_seqlens[i]:cu_seqlens[i + 1], ...]
torch.testing.assert_close(
Y_seq[:, :chunked_seqlens[i], ...],
Y_ref_seq[:, :chunked_seqlens[i], ...],
Y_seq[:chunked_seqlens[i], ...],
Y_ref_seq[:chunked_seqlens[i], ...],
atol=atol,
rtol=rtol,
msg=lambda x: f"seq{i} output part1 " + x) # noqa: B023
torch.testing.assert_close(
Y_seq[:, chunked_seqlens[i]:, ...],
Y_ref_seq[:, chunked_seqlens[i]:, ...],
Y_seq[chunked_seqlens[i]:, ...],
Y_ref_seq[chunked_seqlens[i]:, ...],
atol=atol,
rtol=rtol,
msg=lambda x: f"seq{i} output part2 " + x) # noqa: B023
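
The chunked-prefill test above splits each prompt into a first part and a remainder, runs the kernel twice (the second pass seeded with initial_states=partial_state), and checks that the concatenation of both partial outputs matches a single full pass. In the new layout that concatenation happens along dim=0, the token axis; a toy sketch with made-up lengths:

import torch

# Dummy shapes only; mirrors the concat check in the chunked-prefill test,
# where per-sequence pieces now concatenate along dim=0 (the token axis).
chunked_cu = torch.tensor([0, 3, 5])       # first parts of 2 sequences
remaining_cu = torch.tensor([0, 2, 6])     # remaining parts of the same sequences
part1 = torch.randn(5, 4, 64)              # outputs of the first pass
part2 = torch.randn(6, 4, 64)              # outputs of the second pass (with initial states)

full = torch.cat([
    torch.cat([part1[chunked_cu[i]:chunked_cu[i + 1]],
               part2[remaining_cu[i]:remaining_cu[i + 1]]], dim=0)
    for i in range(2)
], dim=0)
assert full.shape[0] == part1.shape[0] + part2.shape[0]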

@@ -29,7 +29,7 @@ from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_state_update)
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
mamba_chunk_scan_combined)
mamba_chunk_scan_combined_varlen)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import (
LoaderFunction, composed_weight_loader, sharded_weight_loader)
@@ -504,6 +504,7 @@ class MambaMixer2(MambaBase, CustomOp):
seq_idx_p = attn_metadata.seq_idx_p
chunk_indices_p = attn_metadata.chunk_indices_p
chunk_offsets_p = attn_metadata.chunk_offsets_p
query_start_loc_p = attn_metadata.query_start_loc_p

# 1. Gated MLP's linear projection
projected_states, _ = self.in_proj(hidden_states)
@@ -545,6 +546,7 @@ class MambaMixer2(MambaBase, CustomOp):
out, _ = self.out_proj(hidden_states)
return out

# NOTE: V0 put prefill before decode, v1 puts decode before prefill
num_prefills = attn_metadata.num_prefills # request count
num_decodes = attn_metadata.num_decode_tokens # token count (=request)
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
@@ -570,9 +572,6 @@ class MambaMixer2(MambaBase, CustomOp):
[num_decodes, num_prefills],
dim=0,
)
query_start_loc_p = (
attn_metadata.query_start_loc[-num_prefills - 1:] -
num_decodes if has_prefill else None)

# Preallocate output tensor to avoid memcpy cost for merging prefill
# and decode outputs
@@ -620,15 +619,15 @@ class MambaMixer2(MambaBase, CustomOp):
ssm_state[state_indices_tensor_p], 0)

# NOTE: final output is an in-place update of out tensor
varlen_state = mamba_chunk_scan_combined(
hidden_states_p.view(1, num_prefill_tokens,
varlen_states = mamba_chunk_scan_combined_varlen(
hidden_states_p.view(num_prefill_tokens,
self.num_heads // self.tp_size,
self.head_dim),
dt_p.unsqueeze(0),
dt_p,
self.A,
B_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
B_p.view(num_prefill_tokens, self.n_groups // self.tp_size,
-1),
C_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
C_p.view(num_prefill_tokens, self.n_groups // self.tp_size,
-1),
chunk_size=chunk_size,
D=self.D,
@@ -639,17 +638,15 @@ class MambaMixer2(MambaBase, CustomOp):
chunk_offsets=chunk_offsets_p,
cu_seqlens=query_start_loc_p,
initial_states=initial_states,
return_varlen_states=True,
return_final_states=False,
dt_softplus=True,
dt_limit=(0.0, float("inf")),
out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1,
out=preallocated_ssm_out_p.view(num_prefill_tokens, -1,
self.head_dim),
state_dtype=ssm_state.dtype)

# update ssm states
# - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
ssm_state[state_indices_tensor_p] = varlen_state
ssm_state[state_indices_tensor_p] = varlen_states
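
In the mixer, the prefill tokens are now handed to the kernel as plain (num_prefill_tokens, ...) views, with query_start_loc_p doubling as cu_seqlens, instead of wrapping everything in a batch-of-1 via view(1, ...)/unsqueeze(0). A shape-only sketch with hypothetical sizes (not the layer's real values):

import torch

# Hypothetical sizes; the real values come from attn_metadata and the layer config.
num_prefill_tokens, num_heads, head_dim, n_groups, dstate = 16, 4, 64, 1, 32

hidden_states_p = torch.randn(num_prefill_tokens, num_heads * head_dim)
B_p = torch.randn(num_prefill_tokens, n_groups * dstate)

# New varlen convention: plain (tokens, heads, dim) views, no leading
# batch-of-1 dimension and no unsqueeze(0)/squeeze on dt or the output buffer.
x = hidden_states_p.view(num_prefill_tokens, num_heads, head_dim)
B = B_p.view(num_prefill_tokens, n_groups, -1)
assert x.shape == (num_prefill_tokens, num_heads, head_dim)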

# Process decode requests
if has_decode:

@@ -427,7 +427,7 @@ def causal_conv1d_fn(
batch_ptr = metadata.batch_ptr
token_chunk_offset_ptr = metadata.token_chunk_offset_ptr
else:
seqlens = np.diff(query_start_loc.to('cpu'))
seqlens = query_start_loc.diff().to('cpu')
args = seqlens
MAX_NUM_PROGRAMS = 1024
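
The conv1d metadata path above swaps np.diff for the tensor-native Tensor.diff, so the per-sequence lengths stay a torch tensor; for example:

import torch

query_start_loc = torch.tensor([0, 5, 8, 16])
# Per-sequence lengths without a round-trip through NumPy:
seqlens = query_start_loc.diff().to('cpu')   # tensor([5, 3, 8])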

@@ -99,34 +99,28 @@ def _bmm_chunk_fwd_kernel(
seq_idx_ptr,
# Matrix dimensions
seqlen,
chunk_size,
K,
ngroups,
stride_a_batch,
stride_a_seqlen,
stride_a_head,
stride_ak,
stride_b_batch,
stride_b_seqlen,
stride_b_head,
stride_bk,
stride_out_batch,
stride_out_chunk,
stride_out_head,
stride_outm,
stride_outn,
stride_seq_idx_batch,
stride_seq_idx_seqlen,
chunk_size: tl.constexpr,
K: tl.constexpr,
ngroups: tl.constexpr,
stride_a_seqlen: tl.int64,
stride_a_head: tl.int64,
stride_ak: tl.constexpr,
stride_b_seqlen: tl.int64,
stride_b_head: tl.int64,
stride_bk: tl.constexpr,
stride_out_chunk: tl.int64,
stride_out_head: tl.int64,
stride_outm: tl.int64,
stride_outn: tl.constexpr,
stride_seq_idx_seqlen: tl.constexpr,
# Meta-parameters
IS_CAUSAL: tl.constexpr,
dot_dtype: tl.constexpr,
HAS_SEQ_IDX: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
):
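
The rewritten kernel signatures annotate scalar parameters: sizes and innermost strides become tl.constexpr (baked in at compile time) while the larger strides become tl.int64 (avoiding 32-bit offset overflow on long token axes). A small standalone Triton example of this annotation style, assuming a reasonably recent Triton; it is not taken from this commit:

import torch
import triton
import triton.language as tl

@triton.jit
def scale_rows_kernel(
        x_ptr, out_ptr,
        n_cols: tl.constexpr,       # compile-time constant: specializes the kernel
        stride_row: tl.int64,       # 64-bit scalar: safe for very large tensors
        BLOCK: tl.constexpr):
    row = tl.program_id(0)
    offs = tl.arange(0, BLOCK)
    mask = offs < n_cols
    # assumes the last dimension is contiguous (unit stride)
    x = tl.load(x_ptr + row * stride_row + offs, mask=mask)
    tl.store(out_ptr + row * stride_row + offs, x * 2.0, mask=mask)

def scale_rows(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    BLOCK = triton.next_power_of_2(x.shape[1])
    scale_rows_kernel[(x.shape[0], )](x, out,
                                      n_cols=x.shape[1],
                                      stride_row=x.stride(0),
                                      BLOCK=BLOCK)
    return out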
pid_b = tl.program_id(axis=1)
pid_ch = tl.program_id(axis=2).to(tl.int64)
pid_ch = tl.program_id(axis=1).to(tl.int64)
pid_c = pid_ch // ngroups
pid_h = pid_ch - pid_c * ngroups
num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
@@ -135,10 +129,10 @@ def _bmm_chunk_fwd_kernel(
if IS_CAUSAL:
if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
return
a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head
b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head
if HAS_SEQ_IDX:
seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
a_ptr += pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head
b_ptr += pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head

seq_idx_ptr += pid_c * chunk_size * stride_seq_idx_seqlen

offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@@ -150,6 +144,8 @@ def _bmm_chunk_fwd_kernel(
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)

acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

# compute a * b.T
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs,
mask=(offs_m[:, None] < chunk_size_limit) &
@@ -165,18 +161,19 @@ def _bmm_chunk_fwd_kernel(

offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
if HAS_SEQ_IDX:
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
mask=offs_m < chunk_size_limit,
other=-1)
seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen,
mask=offs_n < chunk_size_limit,
other=-2)
acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)
out = acc.to(out_ptr.dtype.element_ty)

out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head
# Zero out the results that are not from the same request
# in the varlen batch
seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
mask=offs_m < chunk_size_limit,
other=-1)
seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen,
mask=offs_n < chunk_size_limit,
other=-2)
acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)

out = acc.to(out_ptr.dtype.element_ty)
out_ptr += pid_c * stride_out_chunk + pid_h * stride_out_head
out_ptrs = out_ptr + (stride_outm * offs_m[:, None] +
offs_n[None, :] * stride_outn)
tl.store(out_ptrs,
@@ -185,78 +182,61 @@ def _bmm_chunk_fwd_kernel(
(offs_n[None, :] < chunk_size))


def _bmm_chunk_fwd(a,
b,
chunk_size,
seq_idx=None,
causal=False,
output_dtype=None):
def _bmm_chunk_fwd(a, b, chunk_size, seq_idx, causal=False, output_dtype=None):
"""
Argument:
a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
b: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
seq_idx: (batch, seqlen) or None. out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out.
a: (seqlen, ngroups, k)
b: (seqlen, ngroups, k)
seq_idx: (seqlen,). out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out.
causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are
guaranteed to be correct.
Return:
out: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size)
out: (nchunks, ngroups, chunk_size, chunk_size)
"""
# Check constraints.
has_groups = a.dim() == 4
if not has_groups:
batch, seqlen, k = a.shape
else:
batch, seqlen, ngroups, k = a.shape
seqlen, ngroups, k = a.shape
assert b.shape == a.shape
if seq_idx is not None:
assert seq_idx.shape == (batch, seqlen)
if a.stride(-1) != 1 and a.stride(1) != 1:
assert seq_idx is not None
assert seq_idx.shape == (seqlen, )
if a.stride(-1) != 1 and a.stride(0) != 1:
a = a.contiguous()
if b.stride(-1) != 1 and b.stride(1) != 1:
if b.stride(-1) != 1 and b.stride(0) != 1:
b = b.contiguous()

nchunks = math.ceil(seqlen / chunk_size)
# Allocates output.
out_dtype = a.dtype if output_dtype is None else output_dtype
out = torch.empty(
(batch, nchunks, chunk_size, chunk_size) if not has_groups else
(batch, nchunks, ngroups, chunk_size, chunk_size),
device=a.device,
dtype=out_dtype)
out = torch.empty((nchunks, ngroups, chunk_size, chunk_size),
device=a.device,
dtype=out_dtype)
dot_dtype = (tl.bfloat16
if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else
(tl.float16 if a.dtype == torch.float16
or b.dtype == torch.float16 else tl.float32))
grid = lambda META: (triton.cdiv(
chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(
chunk_size, META['BLOCK_SIZE_N']), batch, nchunks
if not has_groups else nchunks * ngroups)
chunk_size, META['BLOCK_SIZE_N']), nchunks * ngroups)
with torch.cuda.device(a.device.index):
_bmm_chunk_fwd_kernel[grid](
a,
b,
out,
seq_idx,
seqlen,
chunk_size,
k,
ngroups if has_groups else 1,
a.stride(0),
a.stride(1),
0 if not has_groups else a.stride(2),
a.stride(-1),
b.stride(0),
b.stride(1),
0 if not has_groups else b.stride(2),
b.stride(-1),
out.stride(0),
out.stride(1),
0 if not has_groups else out.stride(2),
out.stride(-2),
out.stride(-1),
*((seq_idx.stride(0),
seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
causal,
dot_dtype,
HAS_SEQ_IDX=seq_idx is not None,
a_ptr=a,
b_ptr=b,
out_ptr=out,
seq_idx_ptr=seq_idx,
seqlen=seqlen,
chunk_size=chunk_size,
K=k,
ngroups=ngroups,
stride_a_seqlen=a.stride(0),
stride_a_head=a.stride(1),
stride_ak=a.stride(2),
stride_b_seqlen=b.stride(0),
stride_b_head=b.stride(1),
stride_bk=b.stride(2),
stride_out_chunk=out.stride(0),
stride_out_head=out.stride(1),
stride_outm=out.stride(-2),
stride_outn=out.stride(-1),
stride_seq_idx_seqlen=seq_idx.stride(0),
IS_CAUSAL=causal,
dot_dtype=dot_dtype,
)
return out
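
A naive reference for the new _bmm_chunk_fwd contract documented above, assuming the output is the per-chunk, per-group a @ b.T with entries from different requests zeroed via seq_idx (the causal flag and output dtype handling are ignored here); illustrative only:

import math
import torch

def bmm_chunk_ref(a, b, chunk_size, seq_idx):
    """Naive reference for the shapes in the docstring above (illustrative only)."""
    seqlen, ngroups, k = a.shape
    nchunks = math.ceil(seqlen / chunk_size)
    out = torch.zeros(nchunks, ngroups, chunk_size, chunk_size,
                      dtype=a.dtype, device=a.device)
    for c in range(nchunks):
        s, e = c * chunk_size, min((c + 1) * chunk_size, seqlen)
        ac, bc, idx = a[s:e], b[s:e], seq_idx[s:e]
        # per-group A @ B^T for this chunk
        prod = torch.einsum('mgk,ngk->gmn', ac, bc)
        # zero entries whose rows/cols belong to different requests
        prod = prod * (idx[:, None] == idx[None, :]).to(prod.dtype)
        out[c, :, :e - s, :e - s] = prod
    return out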

@@ -6,7 +6,6 @@

# ruff: noqa: E501,SIM102

import torch
from packaging import version

from vllm.triton_utils import tl, triton
@@ -114,7 +113,6 @@ def _chunk_scan_fwd_kernel(
x_ptr,
z_ptr,
out_ptr,
out_x_ptr,
dt_ptr,
dA_cumsum_ptr,
seq_idx_ptr,
@@ -126,60 +124,49 @@ def _chunk_scan_fwd_kernel(
chunk_offsets_ptr,
chunk_meta_num,
# Matrix dimensions
chunk_size,
hdim,
dstate,
batch,
chunk_size: tl.constexpr,
hdim: tl.constexpr,
dstate: tl.constexpr,
seqlen,
nheads_ngroups_ratio,
nheads_ngroups_ratio: tl.constexpr,
# Strides
stride_cb_batch,
stride_cb_chunk,
stride_cb_head,
stride_cb_csize_m,
stride_cb_csize_k,
stride_x_batch,
stride_x_seqlen,
stride_x_head,
stride_x_hdim,
stride_z_batch,
stride_z_seqlen,
stride_z_head,
stride_z_hdim,
stride_out_batch,
stride_out_seqlen,
stride_out_head,
stride_out_hdim,
stride_dt_batch,
stride_dt_chunk,
stride_dt_head,
stride_dt_csize,
stride_dA_cs_batch,
stride_dA_cs_chunk,
stride_dA_cs_head,
stride_dA_cs_csize,
stride_seq_idx_batch,
stride_seq_idx_seqlen,
stride_C_batch,
stride_C_seqlen,
stride_C_head,
stride_C_dstate,
stride_states_batch,
stride_states_chunk,
stride_states_head,
stride_states_hdim,
stride_states_dstate,
stride_init_states_batch,
stride_init_states_head,
stride_init_states_hdim,
stride_init_states_dstate,
stride_D_head,
stride_cb_chunk: tl.int64,
stride_cb_head: tl.int64,
stride_cb_csize_m: tl.int64,
stride_cb_csize_k: tl.constexpr,
stride_x_seqlen: tl.int64,
stride_x_head: tl.int64,
stride_x_hdim: tl.constexpr,
stride_z_seqlen: tl.int64,
stride_z_head: tl.int64,
stride_z_hdim: tl.constexpr,
stride_out_seqlen: tl.int64,
stride_out_head: tl.int64,
stride_out_hdim: tl.constexpr,
stride_dt_chunk: tl.int64,
stride_dt_head: tl.int64,
stride_dt_csize: tl.constexpr,
stride_dA_cs_chunk: tl.int64,
stride_dA_cs_head: tl.int64,
stride_dA_cs_csize: tl.constexpr,
stride_seq_idx_seqlen: tl.constexpr,
stride_C_seqlen: tl.int64,
stride_C_head: tl.int64,
stride_C_dstate: tl.constexpr,
stride_states_chunk: tl.int64,
stride_states_head: tl.int64,
stride_states_hdim: tl.int64,
stride_states_dstate: tl.constexpr,
stride_init_states_batch: tl.int64,
stride_init_states_head: tl.int64,
stride_init_states_hdim: tl.int64,
stride_init_states_dstate: tl.constexpr,
stride_D_head: tl.constexpr,
# Meta-parameters
IS_CAUSAL: tl.constexpr,
HAS_D: tl.constexpr,
D_HAS_HDIM: tl.constexpr,
HAS_Z: tl.constexpr,
HAS_SEQ_IDX: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
@@ -187,9 +174,7 @@ def _chunk_scan_fwd_kernel(
IS_TRITON_22: tl.constexpr,
HAS_INITSTATES: tl.constexpr,
):
pid_bc = tl.program_id(axis=1).to(tl.int64)
pid_c = pid_bc // batch
pid_b = pid_bc - pid_c * batch
pid_c = tl.program_id(axis=1).to(tl.int64)
if not HAS_INITSTATES:
c_idx = pid_c
c_off = 0
@@ -201,53 +186,51 @@ def _chunk_scan_fwd_kernel(
num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
pid_m = tl.program_id(axis=0) // num_pid_n
pid_n = tl.program_id(axis=0) % num_pid_n
cb_ptr += pid_b * stride_cb_batch + c_idx * stride_cb_chunk + (
pid_h // nheads_ngroups_ratio) * stride_cb_head
x_ptr += pid_b * stride_x_batch + c_idx * chunk_size * stride_x_seqlen + pid_h * stride_x_head
dt_ptr += pid_b * stride_dt_batch + c_idx * stride_dt_chunk + pid_h * stride_dt_head
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + c_idx * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
C_ptr += pid_b * stride_C_batch + c_idx * chunk_size * stride_C_seqlen + (
cb_ptr += c_idx * stride_cb_chunk + (pid_h //
nheads_ngroups_ratio) * stride_cb_head
x_ptr += c_idx * chunk_size * stride_x_seqlen + pid_h * stride_x_head
dt_ptr += c_idx * stride_dt_chunk + pid_h * stride_dt_head
dA_cumsum_ptr += c_idx * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
C_ptr += c_idx * chunk_size * stride_C_seqlen + (
pid_h // nheads_ngroups_ratio) * stride_C_head

# M-block offsets and prev states
# - logic in next block may override these if there is an active offset
offs_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M)
prev_states_ptr = states_ptr + pid_b * stride_states_batch + c_idx * stride_states_chunk + pid_h * stride_states_head
prev_states_ptr = states_ptr + c_idx * stride_states_chunk + pid_h * stride_states_head
prev_states_hdim = stride_states_hdim
prev_states_dstate = stride_states_dstate

chunk_size_limit = min(chunk_size, seqlen - c_idx * chunk_size)
if HAS_SEQ_IDX:
seq_idx_ptr += pid_b * stride_seq_idx_batch + c_idx * chunk_size * stride_seq_idx_seqlen

# - we only need seq_idx_prev to be aligned to chunk boundary
seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen,
mask=c_idx >= 1,
other=0)
seq_idx_ptr += c_idx * chunk_size * stride_seq_idx_seqlen
# - we only need seq_idx_prev to be aligned to chunk boundary
seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen,
mask=c_idx >= 1,
other=0)

if HAS_INITSTATES:
# if there are init states, we only need seq_idx_m to point
# what is the current seq_idx
if HAS_INITSTATES:
# if there are init states, we only need seq_idx_m to point
# what is the current seq_idx

# get current seq idx
if (pid_m * BLOCK_SIZE_M + c_off) < chunk_size_limit:
seq_idx_m = tl.load(
seq_idx_ptr +
(pid_m * BLOCK_SIZE_M + c_off) * stride_seq_idx_seqlen, )
# get current seq idx
if (pid_m * BLOCK_SIZE_M + c_off) < chunk_size_limit:
seq_idx_m = tl.load(
seq_idx_ptr +
(pid_m * BLOCK_SIZE_M + c_off) * stride_seq_idx_seqlen, )

# - recall that in ssd_state_passing, for the case c_off == 0
# i.e., the very first sequence, we made states_ptr hold its initial state
# so this edge case is taken care of
if ((c_off == 0) and
(seq_idx_prev != seq_idx_m
) # if a seq is changed exactly on boundary
or (c_off > 0) # implies a new example (pseudo chunk)
):
# - recall that in ssd_state_passing, for the case c_off == 0
# i.e., the very first sequence, we made states_ptr hold its initial state
# so this edge case is taken care of
if ((c_off == 0) and (seq_idx_prev != seq_idx_m
) # if a seq is changed exactly on boundary
or (c_off > 0) # implies a new example (pseudo chunk)
):

# - replace prev_states_ptr with init_states
prev_states_ptr = initstates_ptr + seq_idx_m * stride_init_states_batch + pid_h * stride_init_states_head
prev_states_hdim = stride_init_states_hdim # override strides
prev_states_dstate = stride_init_states_dstate
# - replace prev_states_ptr with init_states
prev_states_ptr = initstates_ptr + seq_idx_m * stride_init_states_batch + pid_h * stride_init_states_head
prev_states_hdim = stride_init_states_hdim # override strides
prev_states_dstate = stride_init_states_dstate

offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize,
@@ -256,7 +239,6 @@ def _chunk_scan_fwd_kernel(

# - handle chunk state limit
if HAS_INITSTATES:

# have to split this if otherwise compilation will have problems
dA_cs_m_boundary = 0.0

@@ -296,13 +278,11 @@ def _chunk_scan_fwd_kernel(
dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize,
mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)),
other=0.0).to(tl.float32)

if HAS_SEQ_IDX:
else:
# - handle seq idx when HAS_INITSTATES==False
if not HAS_INITSTATES:
seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
mask=offs_m < chunk_size_limit,
other=-1)
seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
mask=offs_m < chunk_size_limit,
other=-1)

acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

@@ -319,18 +299,15 @@ def _chunk_scan_fwd_kernel(
prev_states_ptrs = prev_states_ptr + (
offs_n[None, :] * prev_states_hdim +
offs_k_dstate[:, None] * prev_states_dstate)
if HAS_SEQ_IDX:

if not HAS_INITSTATES:
# - this is for continuous batching where there is no init states
scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m),
0.0)
else:
# - if there is initstates, we will rely on prev_states, no zeroing
# required.
scale_m = tl.exp(dA_cs_m - dA_cs_m_boundary)
if not HAS_INITSTATES:
# - this is for continuous batching where there is no init states
scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0)
else:
scale_m = tl.exp(dA_cs_m)
# - if there is initstates, we will rely on prev_states, no zeroing
# required.
scale_m = tl.exp(dA_cs_m - dA_cs_m_boundary)

if BLOCK_SIZE_DSTATE <= 128:
C = tl.load(C_ptrs,
mask=(offs_m[:, None] < chunk_size_limit) &
@@ -416,15 +393,7 @@ def _chunk_scan_fwd_kernel(
acc += x_residual * D

if HAS_Z:
out_x_ptr += pid_b * stride_out_batch + c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head
out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] +
offs_out_n[None, :])
tl.store(out_x_ptrs,
acc,
mask=(offs_out_m[:, None] < chunk_size_limit) &
(offs_out_n[None, :] < hdim))

z_ptr += pid_b * stride_z_batch + c_idx * chunk_size * stride_z_seqlen + pid_h * stride_z_head
z_ptr += c_idx * chunk_size * stride_z_seqlen + pid_h * stride_z_head
z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] +
stride_z_hdim * offs_out_n[None, :])
z = tl.load(z_ptrs,
@@ -433,7 +402,7 @@ def _chunk_scan_fwd_kernel(
other=0.0).to(tl.float32)
acc *= z * tl.sigmoid(z)

out_ptr += pid_b * stride_out_batch + c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head
out_ptr += c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head
out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] +
offs_out_n[None, :] * stride_out_hdim)
tl.store(out_ptrs,
@@ -449,126 +418,110 @@ def _chunk_scan_fwd(
dA_cumsum,
C,
states,
out,
seq_idx,
D=None,
z=None,
seq_idx=None,
chunk_indices=None,
chunk_offsets=None,
initial_states=None,
out=None,
):
batch, seqlen, nheads, headdim = x.shape
_, _, nchunks, chunk_size = dt.shape
_, _, ngroups, dstate = C.shape
assert seq_idx is not None, "this implementation requires seq_idx"

seqlen, nheads, headdim = x.shape
_, nchunks, chunk_size = dt.shape
_, ngroups, dstate = C.shape
assert nheads % ngroups == 0
assert C.shape == (batch, seqlen, ngroups, dstate)
assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
if z is not None:
assert z.shape == x.shape
assert C.shape == (seqlen, ngroups, dstate)
assert cb.shape == (nchunks, ngroups, chunk_size, chunk_size)
if D is not None:
assert D.shape == (nheads, headdim) or D.shape == (nheads, )
assert dt.shape == (batch, nheads, nchunks, chunk_size)
assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
assert states.shape == (batch, nchunks, nheads, headdim, dstate)
if z is not None:
assert z.shape == x.shape
assert dt.shape == (nheads, nchunks, chunk_size)
assert dA_cumsum.shape == (nheads, nchunks, chunk_size)
assert states.shape == (nchunks, nheads, headdim, dstate)
assert seq_idx.shape == (seqlen, )

if seq_idx is not None:
assert seq_idx.shape == (batch, seqlen)

if initial_states is not None:
# with initial states, we need to take care of how
# seq_idx crosses the boundaries
assert batch == 1, "chunk scan only supports initial states with batch 1"
assert chunk_indices is not None and chunk_offsets is not None, \
"chunk_indices and chunk_offsets should have been set"
else:
chunk_indices, chunk_offsets = None, None
if initial_states is not None:
# with initial states, we need to take care of how
# seq_idx crosses the boundaries
assert chunk_indices is not None and chunk_offsets is not None, \
"chunk_indices and chunk_offsets should have been set"
else:
chunk_indices, chunk_offsets = None, None

assert out.shape == x.shape

if z is not None:
out_x = torch.empty_like(x)
assert out_x.stride() == out.stride()
else:
out_x = None

grid = lambda META: (
triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(
headdim, META['BLOCK_SIZE_N']), batch * nchunks
headdim, META['BLOCK_SIZE_N']), nchunks
if chunk_offsets is None else len(chunk_offsets), nheads)
z_strides = ((z.stride(0), z.stride(1), z.stride(2),
z.stride(3)) if z is not None else (0, 0, 0, 0))

z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else
(0, 0, 0))
initial_states_strides = ((initial_states.stride(0),
initial_states.stride(1),
initial_states.stride(2),
initial_states.stride(3))
if initial_states is not None else (0, 0, 0, 0))

_chunk_scan_fwd_kernel[grid](
cb,
x,
z,
out,
out_x,
dt,
dA_cumsum,
seq_idx,
C,
states,
D,
initial_states,
chunk_indices,
chunk_offsets,
len(chunk_indices) if chunk_indices is not None else 0,
chunk_size,
headdim,
dstate,
batch,
seqlen,
nheads // ngroups,
cb.stride(0),
cb.stride(1),
cb.stride(2),
cb.stride(3),
cb.stride(4),
x.stride(0),
x.stride(1),
x.stride(2),
x.stride(3),
z_strides[0],
z_strides[1],
z_strides[2],
z_strides[3],
out.stride(0),
out.stride(1),
out.stride(2),
out.stride(3),
dt.stride(0),
dt.stride(2),
dt.stride(1),
dt.stride(3),
dA_cumsum.stride(0),
dA_cumsum.stride(2),
dA_cumsum.stride(1),
dA_cumsum.stride(3),
*((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else
(0, 0)),
C.stride(0),
C.stride(1),
C.stride(2),
C.stride(3),
states.stride(0),
states.stride(1),
states.stride(2),
states.stride(3),
states.stride(4),
*((initial_states.stride(0), initial_states.stride(1),
initial_states.stride(2),
initial_states.stride(3)) if initial_states is not None else
(0, 0, 0, 0)),
D.stride(0) if D is not None else 0,
True,
D is not None,
D.dim() == 2 if D is not None else True,
BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
cb_ptr=cb,
x_ptr=x,
z_ptr=z,
out_ptr=out,
dt_ptr=dt,
dA_cumsum_ptr=dA_cumsum,
seq_idx_ptr=seq_idx,
C_ptr=C,
states_ptr=states,
D_ptr=D,
initstates_ptr=initial_states,
chunk_indices_ptr=chunk_indices,
chunk_offsets_ptr=chunk_offsets,
chunk_meta_num=len(chunk_indices) if chunk_indices is not None else 0,
chunk_size=chunk_size,
hdim=headdim,
dstate=dstate,
seqlen=seqlen,
nheads_ngroups_ratio=nheads // ngroups,
stride_cb_chunk=cb.stride(0),
stride_cb_head=cb.stride(1),
stride_cb_csize_m=cb.stride(2),
stride_cb_csize_k=cb.stride(3),
stride_x_seqlen=x.stride(0),
stride_x_head=x.stride(1),
stride_x_hdim=x.stride(2),
stride_z_seqlen=z_strides[0],
stride_z_head=z_strides[1],
stride_z_hdim=z_strides[2],
stride_out_seqlen=out.stride(0),
stride_out_head=out.stride(1),
stride_out_hdim=out.stride(2),
stride_dt_chunk=dt.stride(1),
stride_dt_head=dt.stride(0),
stride_dt_csize=dt.stride(2),
stride_dA_cs_chunk=dA_cumsum.stride(1),
stride_dA_cs_head=dA_cumsum.stride(0),
stride_dA_cs_csize=dA_cumsum.stride(2),
stride_seq_idx_seqlen=seq_idx.stride(0),
stride_C_seqlen=C.stride(0),
stride_C_head=C.stride(1),
stride_C_dstate=C.stride(2),
stride_states_chunk=states.stride(0),
stride_states_head=states.stride(1),
stride_states_hdim=states.stride(2),
stride_states_dstate=states.stride(3),
stride_init_states_batch=initial_states_strides[0],
stride_init_states_head=initial_states_strides[1],
stride_init_states_hdim=initial_states_strides[2],
stride_init_states_dstate=initial_states_strides[3],
stride_D_head=D.stride(0) if D is not None else 0,
IS_CAUSAL=True,
HAS_D=D is not None,
D_HAS_HDIM=D.dim() == 2 if D is not None else True,
HAS_Z=z is not None,
HAS_SEQ_IDX=seq_idx is not None,
BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
IS_TRITON_22=TRITON_22,
HAS_INITSTATES=initial_states is not None,
)
return out_x
return

@@ -35,41 +35,35 @@ def _chunk_cumsum_fwd_kernel(
dt_out_ptr,
dA_cumsum_ptr,
# Matrix dimension
batch,
seqlen,
nheads,
chunk_size,
dt_min,
dt_max,
nheads: tl.constexpr,
chunk_size: tl.constexpr,
dt_min: tl.constexpr,
dt_max: tl.constexpr,
# Strides
stride_dt_batch,
stride_dt_seqlen,
stride_dt_head,
stride_A_head,
stride_dt_bias_head,
stride_dt_out_batch,
stride_dt_out_chunk,
stride_dt_out_head,
stride_dt_out_csize,
stride_dA_cs_batch,
stride_dA_cs_chunk,
stride_dA_cs_head,
stride_dA_cs_csize,
stride_dt_seqlen: tl.int64,
stride_dt_head: tl.constexpr,
stride_A_head: tl.constexpr,
stride_dt_bias_head: tl.constexpr,
stride_dt_out_head: tl.int64,
stride_dt_out_chunk: tl.int64,
stride_dt_out_csize: tl.constexpr,
stride_dA_cs_head: tl.int64,
stride_dA_cs_chunk: tl.int64,
stride_dA_cs_csize: tl.constexpr,
# Meta-parameters
DT_SOFTPLUS: tl.constexpr,
HAS_DT_BIAS: tl.constexpr,
BLOCK_SIZE_H: tl.constexpr,
BLOCK_SIZE_CHUNK: tl.constexpr,
):
pid_b = tl.program_id(axis=0)

# if dt is long, may cause problems, so use 64 bit
# https://github.com/triton-lang/triton/issues/1058
pid_c = tl.program_id(axis=1).to(tl.int64)
pid_h = tl.program_id(axis=2)
dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk
pid_c = tl.program_id(axis=0).to(tl.int64)
pid_h = tl.program_id(axis=1)
dt_ptr += pid_c * chunk_size * stride_dt_seqlen
dt_out_ptr += pid_c * stride_dt_out_chunk
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk

offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
@@ -93,9 +87,8 @@ def _chunk_cumsum_fwd_kernel(
dt += dt_bias[:, None]
if DT_SOFTPLUS:
dt = tl.where(dt <= 20.0, softplus(dt), dt)
# As of Triton 2.2.0, tl.clamp is not available yet
# dt = tl.clamp(dt, dt_min, dt_max)
dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)

dt = tl.clamp(dt, dt_min, dt_max)
dt = tl.where(
(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt,
0.0)
@@ -197,56 +190,46 @@ def _chunk_state_fwd_kernel(
dA_cumsum_ptr,
seq_idx_ptr,
# Matrix dimensions
hdim,
dstate,
chunk_size,
batch,
hdim: tl.constexpr,
dstate: tl.constexpr,
chunk_size: tl.constexpr,
seqlen,
nheads_ngroups_ratio,
nheads_ngroups_ratio: tl.constexpr,
# Strides
stride_x_batch,
stride_x_seqlen,
stride_x_head,
stride_x_hdim,
stride_b_batch,
stride_b_seqlen,
stride_b_head,
stride_b_dstate,
stride_states_batch,
stride_states_chunk,
stride_states_head,
stride_states_hdim,
stride_states_dstate,
stride_dt_batch,
stride_dt_chunk,
stride_dt_head,
stride_dt_csize,
stride_dA_cs_batch,
stride_dA_cs_chunk,
stride_dA_cs_head,
stride_dA_cs_csize,
stride_seq_idx_batch,
stride_seq_idx_seqlen,
stride_x_seqlen: tl.int64,
stride_x_head: tl.int64,
stride_x_hdim: tl.constexpr,
stride_b_seqlen: tl.int64,
stride_b_head: tl.int64,
stride_b_dstate: tl.constexpr,
stride_states_chunk: tl.int64,
stride_states_head: tl.int64,
stride_states_hdim: tl.int64,
stride_states_dstate: tl.constexpr,
stride_dt_head: tl.int64,
stride_dt_chunk: tl.int64,
stride_dt_csize: tl.constexpr,
stride_dA_cs_head: tl.int64,
stride_dA_cs_chunk: tl.int64,
stride_dA_cs_csize: tl.constexpr,
stride_seq_idx_seqlen: tl.constexpr,
# Meta-parameters
HAS_SEQ_IDX: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
):
pid_bc = tl.program_id(axis=1).to(tl.int64)
pid_c = pid_bc // batch
pid_b = pid_bc - pid_c * batch
pid_c = tl.program_id(axis=1).to(tl.int64)
pid_h = tl.program_id(axis=2)
num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
pid_m = tl.program_id(axis=0) // num_pid_n
pid_n = tl.program_id(axis=0) % num_pid_n
b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (
b_ptr += pid_c * chunk_size * stride_b_seqlen + (
pid_h // nheads_ngroups_ratio) * stride_b_head
x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
if HAS_SEQ_IDX:
seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head

seq_idx_ptr += pid_c * chunk_size * stride_seq_idx_seqlen

offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@@ -259,13 +242,11 @@ def _chunk_state_fwd_kernel(
dA_cs_last = tl.load(dA_cumsum_ptr +
(chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
if HAS_SEQ_IDX:
seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen

seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
if HAS_SEQ_IDX:
seq_idx_last = tl.load(seq_idx_ptr +
(chunk_size_limit - 1) * stride_seq_idx_seqlen)
seq_idx_last = tl.load(seq_idx_ptr +
(chunk_size_limit - 1) * stride_seq_idx_seqlen)

acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
@@ -280,29 +261,28 @@ def _chunk_state_fwd_kernel(
dA_cs_k = tl.load(dA_cumsum_ptrs,
mask=offs_k < chunk_size_limit - k,
other=0.0).to(tl.float32)
if HAS_SEQ_IDX:
seq_idx_k = tl.load(seq_idx_ptrs,
mask=offs_k < chunk_size_limit - k,
other=-1)

seq_idx_k = tl.load(seq_idx_ptrs,
mask=offs_k < chunk_size_limit - k,
other=-1)
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k,
other=0.0).to(tl.float32)
if not HAS_SEQ_IDX:
scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k
else:
scale = tl.where(seq_idx_k == seq_idx_last,
tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0)

scale = tl.where(seq_idx_k == seq_idx_last,
tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0)
b *= scale[:, None]
b = b.to(x_ptr.dtype.element_ty)
acc += tl.dot(x, b)

x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
if HAS_SEQ_IDX:
seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen
seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen

states = acc.to(states_ptr.dtype.element_ty)

states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head
states_ptr += pid_c * stride_states_chunk + pid_h * stride_states_head
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim +
@@ -400,36 +380,35 @@ def _chunk_state_varlen_kernel(
states_ptr,
initstates_ptr,
# Matrix dimensions
hdim,
dstate,
chunk_size,
seqlen,
nheads_ngroups_ratio,
hdim: tl.constexpr,
dstate: tl.constexpr,
chunk_size: tl.constexpr,
nheads_ngroups_ratio: tl.constexpr,
# Strides
stride_x_seqlen,
stride_x_head,
stride_x_hdim,
stride_b_seqlen,
stride_b_head,
stride_b_dstate,
stride_dt_chunk,
stride_dt_head,
stride_dt_csize,
stride_dA_cs_chunk,
stride_dA_cs_head,
stride_dA_cs_csize,
stride_chunk_states_chunk,
stride_chunk_states_head,
stride_chunk_states_hdim,
stride_chunk_states_dstate,
stride_states_batch,
stride_states_head,
stride_states_hdim,
stride_states_dstate,
stride_init_states_batch,
stride_init_states_head,
stride_init_states_hdim,
stride_init_states_dstate,
stride_x_seqlen: tl.int64,
stride_x_head: tl.int64,
stride_x_hdim: tl.constexpr,
stride_b_seqlen: tl.int64,
stride_b_head: tl.int64,
stride_b_dstate: tl.constexpr,
stride_dt_head: tl.int64,
stride_dt_chunk: tl.int64,
stride_dt_csize: tl.constexpr,
stride_dA_cs_head: tl.int64,
stride_dA_cs_chunk: tl.int64,
stride_dA_cs_csize: tl.constexpr,
stride_chunk_states_chunk: tl.int64,
stride_chunk_states_head: tl.int64,
stride_chunk_states_hdim: tl.int64,
stride_chunk_states_dstate: tl.constexpr,
stride_states_batch: tl.int64,
stride_states_head: tl.int64,
stride_states_hdim: tl.int64,
stride_states_dstate: tl.constexpr,
stride_init_states_batch: tl.int64,
stride_init_states_head: tl.int64,
stride_init_states_hdim: tl.int64,
stride_init_states_dstate: tl.constexpr,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
@@ -558,52 +537,47 @@ def _chunk_cumsum_fwd(dt,
dt_bias=None,
dt_softplus=False,
dt_limit=(0.0, float("inf"))):
batch, seqlen, nheads = dt.shape
seqlen, nheads = dt.shape
assert A.shape == (nheads, )
if dt_bias is not None:
assert dt_bias.shape == (nheads, )
nchunks = math.ceil(seqlen / chunk_size)
dt_out = torch.empty(batch,
nheads,
dt_out = torch.empty(nheads,
nchunks,
chunk_size,
device=dt.device,
dtype=torch.float32)
dA_cumsum = torch.empty(batch,
nheads,
dA_cumsum = torch.empty(nheads,
nchunks,
chunk_size,
device=dt.device,
dtype=torch.float32)
grid_chunk_cs = lambda META: (batch, nchunks,
grid_chunk_cs = lambda META: (nchunks,
triton.cdiv(nheads, META['BLOCK_SIZE_H']))
with torch.cuda.device(dt.device.index):
_chunk_cumsum_fwd_kernel[grid_chunk_cs](
dt,
A,
dt_bias,
dt_out,
dA_cumsum,
batch,
seqlen,
nheads,
chunk_size,
dt_limit[0],
dt_limit[1],
dt.stride(0),
dt.stride(1),
dt.stride(2),
A.stride(0),
dt_bias.stride(0) if dt_bias is not None else 0,
dt_out.stride(0),
dt_out.stride(2),
dt_out.stride(1),
dt_out.stride(3),
dA_cumsum.stride(0),
dA_cumsum.stride(2),
dA_cumsum.stride(1),
dA_cumsum.stride(3),
dt_softplus,
dt_ptr=dt,
A_ptr=A,
dt_bias_ptr=dt_bias,
dt_out_ptr=dt_out,
dA_cumsum_ptr=dA_cumsum,
seqlen=seqlen,
nheads=nheads,
chunk_size=chunk_size,
dt_min=dt_limit[0],
dt_max=dt_limit[1],
stride_dt_seqlen=dt.stride(0),
stride_dt_head=dt.stride(1),
stride_A_head=A.stride(0),
stride_dt_bias_head=dt_bias.stride(0)
if dt_bias is not None else 0,
stride_dt_out_head=dt_out.stride(0),
stride_dt_out_chunk=dt_out.stride(1),
stride_dt_out_csize=dt_out.stride(2),
stride_dA_cs_head=dA_cumsum.stride(0),
stride_dA_cs_chunk=dA_cumsum.stride(1),
stride_dA_cs_csize=dA_cumsum.stride(2),
DT_SOFTPLUS=dt_softplus,
HAS_DT_BIAS=dt_bias is not None,
BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
)
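
The cumsum wrapper above now returns dt_out and dA_cumsum shaped (nheads, nchunks, chunk_size) with no batch dimension. An illustrative reference for those shapes, assuming dA_cumsum is the within-chunk cumulative sum of A[h] * dt and omitting the bias/softplus/clamping that the kernel applies:

import math
import torch

def chunk_cumsum_ref(dt, A, chunk_size):
    """Shape-level reference only; dt is (seqlen, nheads), A is (nheads,)."""
    seqlen, nheads = dt.shape
    nchunks = math.ceil(seqlen / chunk_size)
    pad = nchunks * chunk_size - seqlen
    # pad the token axis so it splits evenly into chunks, then move heads first
    dt_pad = torch.nn.functional.pad(dt, (0, 0, 0, pad))
    dt_out = dt_pad.reshape(nchunks, chunk_size, nheads).permute(2, 0, 1)
    dA_cumsum = (A[:, None, None] * dt_out).cumsum(dim=-1)
    return dt_out, dA_cumsum        # both (nheads, nchunks, chunk_size)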
|
||||
@ -617,63 +591,57 @@ def _chunk_state_fwd(B,
seq_idx=None,
states=None,
states_in_fp32=True):
batch, seqlen, nheads, headdim = x.shape
_, _, nchunks, chunk_size = dt.shape
_, _, ngroups, dstate = B.shape
seqlen, nheads, headdim = x.shape
_, nchunks, chunk_size = dt.shape
_, ngroups, dstate = B.shape
assert nheads % ngroups == 0
assert B.shape == (batch, seqlen, ngroups, dstate)
assert dt.shape == (batch, nheads, nchunks, chunk_size)
assert B.shape == (seqlen, ngroups, dstate)
assert dt.shape == (nheads, nchunks, chunk_size)
assert dA_cumsum.shape == dt.shape
if seq_idx is not None:
assert seq_idx.shape == (batch, seqlen)

assert seq_idx is not None
assert seq_idx.shape == (seqlen, )

if states is not None:
assert states.shape == (batch, nchunks, nheads, headdim, dstate)
assert states.shape == (nchunks, nheads, headdim, dstate)
else:
states_dtype = torch.float32 if states_in_fp32 else B.dtype
states = torch.empty((batch, nchunks, nheads, headdim, dstate),
states = torch.empty((nchunks, nheads, headdim, dstate),
device=x.device,
dtype=states_dtype)
grid = lambda META: (
triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(
dstate, META['BLOCK_SIZE_N']), batch * nchunks, nheads)

grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.
cdiv(dstate, META['BLOCK_SIZE_N']), nchunks, nheads)
with torch.cuda.device(x.device.index):
_chunk_state_fwd_kernel[grid](
x,
B,
states,
dt,
dA_cumsum,
seq_idx,
headdim,
dstate,
chunk_size,
batch,
seqlen,
nheads // ngroups,
x.stride(0),
x.stride(1),
x.stride(2),
x.stride(3),
B.stride(0),
B.stride(1),
B.stride(2),
B.stride(-1),
states.stride(0),
states.stride(1),
states.stride(2),
states.stride(3),
states.stride(4),
dt.stride(0),
dt.stride(2),
dt.stride(1),
dt.stride(3),
dA_cumsum.stride(0),
dA_cumsum.stride(2),
dA_cumsum.stride(1),
dA_cumsum.stride(3),
*((seq_idx.stride(0),
seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
HAS_SEQ_IDX=seq_idx is not None,
x_ptr=x,
b_ptr=B,
states_ptr=states,
dt_ptr=dt,
dA_cumsum_ptr=dA_cumsum,
seq_idx_ptr=seq_idx,
hdim=headdim,
dstate=dstate,
chunk_size=chunk_size,
seqlen=seqlen,
nheads_ngroups_ratio=nheads // ngroups,
stride_x_seqlen=x.stride(0),
stride_x_head=x.stride(1),
stride_x_hdim=x.stride(2),
stride_b_seqlen=B.stride(0),
stride_b_head=B.stride(1),
stride_b_dstate=B.stride(2),
stride_states_chunk=states.stride(0),
stride_states_head=states.stride(1),
stride_states_hdim=states.stride(2),
stride_states_dstate=states.stride(3),
stride_dt_head=dt.stride(0),
stride_dt_chunk=dt.stride(1),
stride_dt_csize=dt.stride(2),
stride_dA_cs_head=dA_cumsum.stride(0),
stride_dA_cs_chunk=dA_cumsum.stride(1),
stride_dA_cs_csize=dA_cumsum.stride(2),
stride_seq_idx_seqlen=seq_idx.stride(0),
)
return states

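The per-chunk state buffer and the launch grid lose their batch dimension as well; a rough sketch of the shapes involved, with made-up sizes for illustration:

import torch

seqlen, nheads, headdim, ngroups, dstate, chunk_size = 1024, 8, 64, 1, 128, 256
nchunks = (seqlen + chunk_size - 1) // chunk_size

x = torch.rand(seqlen, nheads, headdim)
B = torch.rand(seqlen, ngroups, dstate)
seq_idx = torch.zeros(seqlen, dtype=torch.int32)   # now required and 1-D

# states buffer allocated by _chunk_state_fwd:
states = torch.empty(nchunks, nheads, headdim, dstate, dtype=torch.float32)
# launch grid: (tiles over headdim x dstate, nchunks, nheads);
# previously the second grid axis was batch * nchunks.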
@ -705,46 +673,52 @@ def chunk_state_varlen(B,
dstate,
dtype=chunk_states.dtype,
device=chunk_states.device)

initial_states_strides = ((initial_states.stride(0),
initial_states.stride(1),
initial_states.stride(2),
initial_states.stride(3))
if initial_states is not None else (0, 0, 0, 0))

grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.
cdiv(dstate, META['BLOCK_SIZE_N']), batch, nheads)
with torch.cuda.device(x.device.index):
_chunk_state_varlen_kernel[grid](
x,
B,
dt,
dA_cumsum,
chunk_states,
cu_seqlens,
states,
initial_states,
headdim,
dstate,
chunk_size,
total_seqlen,
nheads // ngroups,
x.stride(0),
x.stride(1),
x.stride(2),
B.stride(0),
B.stride(1),
B.stride(2),
dt.stride(1),
dt.stride(0),
dt.stride(2),
dA_cumsum.stride(1),
dA_cumsum.stride(0),
dA_cumsum.stride(2),
chunk_states.stride(0),
chunk_states.stride(1),
chunk_states.stride(2),
chunk_states.stride(3),
states.stride(0),
states.stride(1),
states.stride(2),
states.stride(3),
*((initial_states.stride(0), initial_states.stride(1),
initial_states.stride(2),
initial_states.stride(3)) if initial_states is not None else
(0, 0, 0, 0)),
x_ptr=x,
b_ptr=B,
dt_ptr=dt,
dA_cumsum_ptr=dA_cumsum,
chunk_states_ptr=chunk_states,
cu_seqlens_ptr=cu_seqlens,
states_ptr=states,
initstates_ptr=initial_states,
hdim=headdim,
dstate=dstate,
chunk_size=chunk_size,
nheads_ngroups_ratio=nheads // ngroups,
stride_x_seqlen=x.stride(0),
stride_x_head=x.stride(1),
stride_x_hdim=x.stride(2),
stride_b_seqlen=B.stride(0),
stride_b_head=B.stride(1),
stride_b_dstate=B.stride(2),
stride_dt_head=dt.stride(0),
stride_dt_chunk=dt.stride(1),
stride_dt_csize=dt.stride(2),
stride_dA_cs_head=dA_cumsum.stride(0),
stride_dA_cs_chunk=dA_cumsum.stride(1),
stride_dA_cs_csize=dA_cumsum.stride(2),
stride_chunk_states_chunk=chunk_states.stride(0),
stride_chunk_states_head=chunk_states.stride(1),
stride_chunk_states_hdim=chunk_states.stride(2),
stride_chunk_states_dstate=chunk_states.stride(3),
stride_states_batch=states.stride(0),
stride_states_head=states.stride(1),
stride_states_hdim=states.stride(2),
stride_states_dstate=states.stride(3),
stride_init_states_batch=initial_states_strides[0],
stride_init_states_head=initial_states_strides[1],
stride_init_states_hdim=initial_states_strides[2],
stride_init_states_dstate=initial_states_strides[3],
HAS_INITSTATES=initial_states is not None)
return states

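The wrapper now hoists the strides of the optional initial_states tensor into a tuple before the launch, so the kernel call itself stays keyword-only and branch-free. The same pattern in isolation (the helper name here is illustrative, not part of the commit):

import torch

def optional_strides(t, ndim=4):
    # Zero strides are safe placeholders when the tensor is absent, because the
    # kernel guards every access behind HAS_INITSTATES.
    return tuple(t.stride(i) for i in range(ndim)) if t is not None else (0,) * ndim

initial_states = None
init_strides = optional_strides(initial_states)   # -> (0, 0, 0, 0)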
@ -31,6 +31,7 @@ def _mamba_chunk_scan_combined_fwd(x,
B,
C,
chunk_size,
out,
D=None,
z=None,
dt_bias=None,
@ -41,14 +42,13 @@ def _mamba_chunk_scan_combined_fwd(x,
cu_seqlens=None,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
state_dtype=None,
out=None):
state_dtype=None):
assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2"
batch, seqlen, nheads, headdim = x.shape
_, _, ngroups, dstate = B.shape
seqlen, nheads, headdim = x.shape
_, ngroups, dstate = B.shape
assert nheads % ngroups == 0
assert B.shape == (batch, seqlen, ngroups, dstate)
assert dt.shape == (batch, seqlen, nheads)
assert B.shape == (seqlen, ngroups, dstate)
assert dt.shape == (seqlen, nheads)
assert A.shape == (nheads, )
assert C.shape == B.shape
if z is not None:
@ -56,25 +56,24 @@ def _mamba_chunk_scan_combined_fwd(x,
if D is not None:
assert D.shape == (nheads, headdim) or D.shape == (nheads, )
if seq_idx is not None:
assert seq_idx.shape == (batch, seqlen)
assert seq_idx.shape == (seqlen, )
if B.stride(-1) != 1:
B = B.contiguous()
if C.stride(-1) != 1:
C = C.contiguous()
if x.stride(-1) != 1 and x.stride(
1) != 1: # Either M or K dimension should be contiguous
0) != 1: # Either M or K dimension should be contiguous
x = x.contiguous()
if z is not None and z.stride(-1) != 1 and z.stride(
1) != 1: # Either M or K dimension should be contiguous
0) != 1: # Either M or K dimension should be contiguous
z = z.contiguous()
if D is not None and D.stride(-1) != 1:
D = D.contiguous()
assert cu_seqlens is not None, "Assuming varlen input - must supply cu_seqlens"

if initial_states is not None:
if cu_seqlens is None:
assert initial_states.shape == (batch, nheads, headdim, dstate)
else:
assert initial_states.shape == (len(cu_seqlens) - 1, nheads,
headdim, dstate)
assert initial_states.shape == (len(cu_seqlens) - 1, nheads, headdim,
dstate)

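Since cu_seqlens is now mandatory, every prefill batch is packed as one flat token stream. A hedged sketch of how such an input could be assembled (the sequence lengths below are arbitrary):

import torch

seq_lens = [13, 7, 22]                                   # arbitrary example lengths
nheads, headdim, ngroups, dstate = 8, 64, 1, 128

x = torch.cat([torch.rand(s, nheads, headdim) for s in seq_lens], dim=0)
cu_seqlens = torch.tensor([0] + seq_lens).cumsum(0)      # (batch + 1,)
seq_idx = torch.repeat_interleave(
    torch.arange(len(seq_lens), dtype=torch.int32),
    torch.tensor(seq_lens))                              # (total_seqlen,)
# per-sequence initial states, one slot per entry of cu_seqlens[:-1]
initial_states = torch.zeros(len(seq_lens), nheads, headdim, dstate)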
# This function executes 5 sub-functions for computing mamba
# - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/
@ -114,18 +113,16 @@ def _mamba_chunk_scan_combined_fwd(x,
# - this will ensure that states will be updated with the rightmost flushed seq_idx
# of the previous chunk. This implies that the first chunk of states is either 0
# or equal to init_states of the first example.
states, final_states = _state_passing_fwd(
states = _state_passing_fwd(
rearrange(states, "... p n -> ... (p n)"),
dA_cumsum,
dA_cumsum, # (nheads, nchunks, chunk_size)
initial_states=rearrange(initial_states, "... p n -> ... (p n)")
if initial_states is not None else None,
if initial_states is not None else
None, # (batch, nheads, headdim*dstate)
seq_idx=seq_idx,
chunk_size=chunk_size,
out_dtype=state_dtype if state_dtype is not None else C.dtype,
is_cont_batched=cu_seqlens is not None,
chunk_offsets=chunk_offsets)
states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate)
for t in [states, final_states])
states = rearrange(states, "... (p n) -> ... p n", n=dstate)

# 4. Compute batched matrix multiply for C_j^T B_i terms
CB = _bmm_chunk_fwd(C,
@ -144,87 +141,88 @@ def _mamba_chunk_scan_combined_fwd(x,
# - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had
# a seq_idx change, in which case we take states information from
# init_states.
out_x = _chunk_scan_fwd(
_chunk_scan_fwd(
CB,
x,
dt,
dA_cumsum,
C,
states,
out, # in-place update
seq_idx,
D=D,
z=z,
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
initial_states=initial_states,
out=out,
)
if cu_seqlens is None:
return out_x, dt, dA_cumsum, states, final_states
else:
assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1"
varlen_states = chunk_state_varlen(
B.squeeze(0),
x.squeeze(0),
dt.squeeze(0),
dA_cumsum.squeeze(0),
cu_seqlens,
states.squeeze(0),
initial_states=initial_states,
)
return out_x, dt, dA_cumsum, states, final_states, varlen_states

varlen_states = chunk_state_varlen(
B,
x,
dt,
dA_cumsum,
cu_seqlens,
states,
initial_states=initial_states,
)

return varlen_states


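Read end to end, the forward still runs the same staged pipeline, just on flat varlen tensors, and only the per-request final states are returned. Roughly, with stage names as in this file and argument lists abbreviated:

# 1. _chunk_cumsum_fwd:   dt, A               -> dt, dA_cumsum   (nheads, nchunks, chunk_size)
# 2. _chunk_state_fwd:    B, x, dt, dA_cumsum -> states          (nchunks, nheads, headdim, dstate)
# 3. _state_passing_fwd:  states, dA_cumsum   -> inter-chunk states (no separate final_states anymore)
# 4. _bmm_chunk_fwd:      C, B                -> CB
# 5. _chunk_scan_fwd:     CB, x, dt, states   -> writes the preallocated `out` in place
# 6. chunk_state_varlen:  per-request final states, the only value returned to the caller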
def mamba_chunk_scan_combined(x,
dt,
A,
B,
C,
chunk_size,
D=None,
z=None,
dt_bias=None,
initial_states=None,
seq_idx=None,
chunk_indices=None,
chunk_offsets=None,
cu_seqlens=None,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
out=None,
return_final_states=False,
return_varlen_states=False,
state_dtype=None):
"""
Argument:
x: (batch, seqlen, nheads, headdim)
dt: (batch, seqlen, nheads)
A: (nheads)
B: (batch, seqlen, ngroups, dstate)
C: (batch, seqlen, ngroups, dstate)
chunk_size: int
D: (nheads, headdim) or (nheads,)
z: (batch, seqlen, nheads, headdim)
dt_bias: (nheads,)
initial_states: (batch, nheads, headdim, dstate)
seq_idx: (batch, seqlen)
cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
dt_softplus: Whether to apply softplus to dt
out: Preallocated output tensor
state_dtype: The data type of the ssm state
"""

if not return_varlen_states:
cu_seqlens = None
else:
assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True"
out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd(
def mamba_chunk_scan_combined_varlen(
x,
dt,
A,
B,
C,
chunk_size,
cu_seqlens,
seq_idx,
out,
D=None,
z=None,
dt_bias=None,
initial_states=None,
chunk_indices=None,
chunk_offsets=None,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
state_dtype=None,
):
"""
Argument:
x: (seqlen, nheads, headdim)
dt: (seqlen, nheads)
A: (nheads)
B: (seqlen, ngroups, dstate)
C: (seqlen, ngroups, dstate)
chunk_size: int
seq_idx: (seqlen)
cu_seqlens: (batch + 1)
out: (seqlen, nheads, headdim) preallocated output tensor
D: (nheads, headdim) or (nheads,)
z: (seqlen, nheads, headdim)
dt_bias: (nheads,)
initial_states: (batch, nheads, headdim, dstate)
dt_softplus: Whether to apply softplus to dt
out: (seqlen, nheads, headdim) preallocated output tensor
state_dtype: The data type of the ssm state
Return:
varlen_states: (batch, nheads, headdim, dstate)
"""

assert cu_seqlens is not None, "cu_seqlens must be provided assuming varlen input"
assert seq_idx is not None

varlen_states = _mamba_chunk_scan_combined_fwd(
x,
dt,
A,
B,
C,
chunk_size,
out,
D=D,
z=z,
dt_bias=dt_bias,
@ -235,14 +233,6 @@ def mamba_chunk_scan_combined(x,
cu_seqlens=cu_seqlens,
dt_softplus=dt_softplus,
dt_limit=dt_limit,
out=out,
state_dtype=state_dtype)
if not return_varlen_states:
if not return_final_states:
return
else:
return final_states
else:
varlen_states = rest[0]
return (varlen_states) if not return_final_states else (final_states,
varlen_states)

return varlen_states

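A hedged usage sketch of the new entry point, following the shapes listed in the docstring above. The sizes are arbitrary, a CUDA build with these Triton kernels is assumed, and chunk_indices/chunk_offsets are left at their defaults here (the vLLM attention metadata normally precomputes them):

import torch
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
    mamba_chunk_scan_combined_varlen)

seq_lens = [37, 91]
seqlen = sum(seq_lens)
nheads, headdim, ngroups, dstate, chunk_size = 8, 64, 1, 128, 256

x = torch.rand(seqlen, nheads, headdim, device="cuda")
dt = torch.rand(seqlen, nheads, device="cuda")
A = -torch.rand(nheads, device="cuda")
B = torch.rand(seqlen, ngroups, dstate, device="cuda")
C = torch.rand(seqlen, ngroups, dstate, device="cuda")

cu_seqlens = torch.tensor([0] + seq_lens, device="cuda").cumsum(0)
seq_idx = torch.repeat_interleave(
    torch.arange(len(seq_lens), dtype=torch.int32, device="cuda"),
    torch.tensor(seq_lens, device="cuda"))

out = torch.empty_like(x)                 # written in place by the kernels
varlen_states = mamba_chunk_scan_combined_varlen(
    x, dt, A, B, C, chunk_size,
    cu_seqlens=cu_seqlens, seq_idx=seq_idx, out=out)
# varlen_states: (len(seq_lens), nheads, headdim, dstate)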
@ -27,64 +27,46 @@ def _state_passing_fwd_kernel(
# Pointers to matrices
states_ptr,
out_ptr,
final_states_ptr,
dA_cs_ptr,
initstates_ptr,
seq_idx_ptr,
chunk_offsets_ptr,
chunk_meta_num,
# Matrix dimensions
dim,
dim: tl.constexpr,
nchunks,
seqlen,
chunk_size,
chunk_size: tl.constexpr,
# Strides
stride_states_batch,
stride_states_chunk,
stride_states_head,
stride_states_dim,
stride_out_batch,
stride_out_chunk,
stride_out_head,
stride_out_dim,
stride_final_states_batch,
stride_final_states_head,
stride_final_states_dim,
stride_dA_cs_batch,
stride_dA_cs_chunk,
stride_dA_cs_head,
stride_dA_cs_csize,
stride_initstates_batch,
stride_initstates_head,
stride_initstates_dim,
stride_seq_idx_batch,
stride_seq_idx_seqlen,
stride_states_chunk: tl.int64,
stride_states_head: tl.int64,
stride_states_dim: tl.constexpr,
stride_out_chunk: tl.int64,
stride_out_head: tl.int64,
stride_out_dim: tl.constexpr,
stride_dA_cs_head: tl.int64,
stride_dA_cs_chunk: tl.int64,
stride_dA_cs_csize: tl.constexpr,
stride_initstates_batch: tl.int64,
stride_initstates_head: tl.int64,
stride_initstates_dim: tl.constexpr,
stride_seq_idx_seqlen: tl.constexpr,
# Meta-parameters
HAS_INITSTATES: tl.constexpr,
HAS_SEQ_IDX: tl.constexpr,
IS_CONT_BATCHED: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid_b = tl.program_id(axis=1)
pid_h = tl.program_id(axis=2)
pid_h = tl.program_id(axis=1)
pid_m = tl.program_id(axis=0)
states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + (
chunk_size - 1) * stride_dA_cs_csize
out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head
states_ptr += pid_h * stride_states_head
dA_cs_ptr += pid_h * stride_dA_cs_head + (chunk_size -
1) * stride_dA_cs_csize
out_ptr += pid_h * stride_out_head
if HAS_INITSTATES:
initstates_ptr += pid_h * stride_initstates_head
if not IS_CONT_BATCHED:
initstates_ptr += pid_b * stride_initstates_batch

if HAS_SEQ_IDX:
seq_idx_ptr += pid_b * stride_seq_idx_batch

offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
states_ptrs = states_ptr + offs_m * stride_states_dim
out_ptrs = out_ptr + offs_m * stride_out_dim
final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim

# - states will be the past state of the sequence that continues on the current check
if not HAS_INITSTATES:
@ -101,65 +83,63 @@ def _state_passing_fwd_kernel(
out_ptrs += stride_out_chunk
prev_seq_idx_chunk_end = 0
logical_chunk_idx = 0
for c in range(nchunks):
for c in range(nchunks - 1):
new_states = tl.load(states_ptrs, mask=offs_m < dim,
other=0.0).to(tl.float32)
dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
scale_mask = True
if HAS_SEQ_IDX:
# - the seq to pass forward is the one that is flushed to the right
# boundary.
# - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk.
seq_idx_chunk_end = tl.load(seq_idx_ptr + (min(
(c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen)
if HAS_INITSTATES:
if IS_CONT_BATCHED and prev_seq_idx_chunk_end != seq_idx_chunk_end:
# this means in the current chunk the rightmost flushed seq
# has changed.
# - so we do not propagate the state from previous chunk
# - but rather we load that sequence's init state
initstates_ptrs = initstates_ptr + seq_idx_chunk_end * stride_initstates_batch
# - the seq to pass forward is the one that is flushed to the right
# boundary.
# - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk.
seq_idx_chunk_end = tl.load(seq_idx_ptr +
(min((c + 1) * chunk_size, seqlen) - 1) *
stride_seq_idx_seqlen)

# - update state with seq_idx_new's init state
states = tl.load(initstates_ptrs,
mask=offs_m < dim,
other=0.0).to(tl.float32)
if HAS_INITSTATES:
if prev_seq_idx_chunk_end != seq_idx_chunk_end:
# this means in the current chunk the rightmost flushed seq
# has changed.
# - so we do not propagate the state from previous chunk
# - but rather we load that sequence's init state
initstates_ptrs = initstates_ptr + seq_idx_chunk_end * stride_initstates_batch

# - we need to consider the cumsum only of the last sequence in the chunk
# - find its starting position (given by c_off of the logical chunk index)
# - and subtract the cumsum just before that position from the total cumsum
# - first, update the logical chunk index (add the number of sequences in the current physical chunk):
# sequence index at the start of the current chunk
seq_idx_chunk_start = tl.load(seq_idx_ptr +
min(c * chunk_size, seqlen) *
stride_seq_idx_seqlen)
logical_chunk_idx += seq_idx_chunk_end - seq_idx_chunk_start
# - load the chunk offset:
c_off = tl.load(chunk_offsets_ptr + logical_chunk_idx,
mask=logical_chunk_idx < chunk_meta_num,
other=0)
# - if offset is 0, then the sequence starts at the beginning of the chunk, and we don't need to subtract anything
if c_off > 0:
# - dA_cs_ptr currently points to the cumsum at the end of the chunk - subtract the chunk size and add the offset
dA_cs_boundary = tl.load(
dA_cs_ptr - (chunk_size - 1) * stride_dA_cs_csize +
(c_off - 1) * stride_dA_cs_csize,
mask=(c_off - 1) > -1 and c_off < chunk_size,
other=0.0)
dA_cs -= dA_cs_boundary
# - update state with seq_idx_new's init state
states = tl.load(initstates_ptrs, mask=offs_m < dim,
other=0.0).to(tl.float32)

# - increment logical chunk index for every physical chunk
logical_chunk_idx += 1
else:
scale_mask = seq_idx_chunk_end == prev_seq_idx_chunk_end
prev_seq_idx_chunk_end = seq_idx_chunk_end
# - we need to consider the cumsum only of the last sequence in the chunk
# - find its starting position (given by c_off of the logical chunk index)
# - and subtract the cumsum just before that position from the total cumsum
# - first, update the logical chunk index (add the number of sequences in the current physical chunk):
# sequence index at the start of the current chunk
seq_idx_chunk_start = tl.load(seq_idx_ptr +
min(c * chunk_size, seqlen) *
stride_seq_idx_seqlen)
logical_chunk_idx += seq_idx_chunk_end - seq_idx_chunk_start
# - load the chunk offset:
c_off = tl.load(chunk_offsets_ptr + logical_chunk_idx,
mask=logical_chunk_idx < chunk_meta_num,
other=0)
# - if offset is 0, then the sequence starts at the beginning of the chunk, and we don't need to subtract anything
if c_off > 0:
# - dA_cs_ptr currently points to the cumsum at the end of the chunk - subtract the chunk size and add the offset
dA_cs_boundary = tl.load(
dA_cs_ptr - (chunk_size - 1) * stride_dA_cs_csize +
(c_off - 1) * stride_dA_cs_csize,
mask=(c_off - 1) > -1 and c_off < chunk_size,
other=0.0)
dA_cs -= dA_cs_boundary

# - increment logical chunk index for every physical chunk
logical_chunk_idx += 1
else:
scale_mask = seq_idx_chunk_end == prev_seq_idx_chunk_end
prev_seq_idx_chunk_end = seq_idx_chunk_end

scale = tl.where(scale_mask, tl.exp(dA_cs), 0.0)
states = scale * states + new_states
if c < nchunks - 1:
tl.store(out_ptrs, states, mask=offs_m < dim)
else:
tl.store(final_states_ptrs, states, mask=offs_m < dim)
tl.store(out_ptrs, states, mask=offs_m < dim)

states_ptrs += stride_states_chunk
dA_cs_ptr += stride_dA_cs_chunk
out_ptrs += stride_out_chunk
@ -168,81 +148,53 @@ def _state_passing_fwd_kernel(
def _state_passing_fwd(
states,
dA_cumsum,
seq_idx,
chunk_offsets,
initial_states=None,
seq_idx=None,
chunk_size=None,
out_dtype=None,
is_cont_batched=False,
chunk_offsets=None,
):
batch, nchunks, nheads, dim = states.shape
if chunk_size is None:
chunk_size = dA_cumsum.shape[-1]
else:
assert chunk_size == dA_cumsum.shape[-1]
assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
if initial_states is not None:
if is_cont_batched:
# - if cu_seqlens is provided, then the initial states
# are used for continuous batching. In which case we
# require seq_idx to be provided
assert seq_idx is not None, "seq_idx must be provided for continuous batching"
# - we also need chunk_offsets to be provided, to account
# for computation of dA_cumsum from the start of the
# sequence
assert chunk_offsets is not None, "chunk_offsets must be provided for continuous batching"
else:
# - this is the regular batching case, where initial
# states are used are for each example of the batch.
assert initial_states.shape == (batch, nheads, dim)

if seq_idx is not None:
seqlen = seq_idx.shape[-1]
assert seq_idx.shape == (batch, seqlen)
nchunks, nheads, dim = states.shape
chunk_size = dA_cumsum.shape[-1]
assert dA_cumsum.shape == (nheads, nchunks, chunk_size)
seqlen = seq_idx.shape[-1]
out_dtype = states.dtype if out_dtype is None else out_dtype
out = torch.empty((batch, nchunks, nheads, dim),
out = torch.empty((nchunks, nheads, dim),
device=states.device,
dtype=out_dtype)
final_states = torch.empty((batch, nheads, dim),
device=states.device,
dtype=torch.float32)
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)

initial_states_strides = ((initial_states.stride(0),
initial_states.stride(1),
initial_states.stride(2))
if initial_states is not None else (0, 0, 0))

grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), nheads)
with torch.cuda.device(states.device.index):
_state_passing_fwd_kernel[grid](
states,
out,
final_states,
dA_cumsum,
initial_states,
seq_idx,
chunk_offsets,
len(chunk_offsets) if chunk_offsets is not None else 0,
dim,
nchunks,
seqlen if seq_idx is not None else 0,
chunk_size,
states.stride(0),
states.stride(1),
states.stride(2),
states.stride(3),
out.stride(0),
out.stride(1),
out.stride(2),
out.stride(3),
final_states.stride(0),
final_states.stride(1),
final_states.stride(2),
dA_cumsum.stride(0),
dA_cumsum.stride(2),
dA_cumsum.stride(1),
dA_cumsum.stride(3),
*((initial_states.stride(0), initial_states.stride(1),
initial_states.stride(2)) if initial_states is not None else
(0, 0, 0)),
*((seq_idx.stride(0),
seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
states_ptr=states,
out_ptr=out,
dA_cs_ptr=dA_cumsum,
initstates_ptr=initial_states,
seq_idx_ptr=seq_idx,
chunk_offsets_ptr=chunk_offsets,
chunk_meta_num=len(chunk_offsets)
if chunk_offsets is not None else 0,
dim=dim,
nchunks=nchunks,
seqlen=seqlen if seq_idx is not None else 0,
chunk_size=chunk_size if seq_idx is not None else 0,
stride_states_chunk=states.stride(0),
stride_states_head=states.stride(1),
stride_states_dim=states.stride(2),
stride_out_chunk=out.stride(0),
stride_out_head=out.stride(1),
stride_out_dim=out.stride(2),
stride_dA_cs_head=dA_cumsum.stride(0),
stride_dA_cs_chunk=dA_cumsum.stride(1),
stride_dA_cs_csize=dA_cumsum.stride(2),
stride_initstates_batch=initial_states_strides[0],
stride_initstates_head=initial_states_strides[1],
stride_initstates_dim=initial_states_strides[2],
stride_seq_idx_seqlen=seq_idx.stride(0),
HAS_INITSTATES=initial_states is not None,
HAS_SEQ_IDX=seq_idx is not None,
IS_CONT_BATCHED=is_cont_batched,
)
return out, final_states
return out

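For intuition, an illustrative reference of the inter-chunk recurrence that the rewritten kernel computes, ignoring the continuous-batching offset correction and assuming a single sequence with no initial state (a simplification, not the kernel's exact semantics):

import torch

def state_passing_reference(states, dA_cumsum):
    # states:    (nchunks, nheads, dim)   per-chunk local states
    # dA_cumsum: (nheads, nchunks, chunk_size)
    nchunks, nheads, dim = states.shape
    out = torch.zeros_like(states, dtype=torch.float32)
    running = torch.zeros(nheads, dim)
    for c in range(nchunks - 1):
        scale = torch.exp(dA_cumsum[:, c, -1]).unsqueeze(-1)  # decay over chunk c
        running = scale * running + states[c].float()
        out[c + 1] = running      # state entering chunk c + 1; out[0] stays zero
    return out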
@ -35,7 +35,7 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_state_update)
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
mamba_chunk_scan_combined)
mamba_chunk_scan_combined_varlen)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -262,6 +262,7 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
seq_idx_p = attn_metadata.seq_idx_p
chunk_indices_p = attn_metadata.chunk_indices_p
chunk_offsets_p = attn_metadata.chunk_offsets_p
query_start_loc_p = attn_metadata.query_start_loc_p

# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)
@ -302,9 +303,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
[num_decodes, num_prefills],
dim=0,
)
query_start_loc_p = (
attn_metadata.query_start_loc[-num_prefills - 1:] -
num_decodes if has_prefill else None)

# Preallocate output tensor to avoid memcpy cost for merging prefill
# and decode outputs
@ -356,17 +354,17 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
has_initial_states_p[:, None, None, None],
ssm_state[state_indices_tensor_p], 0)

varlen_state = mamba_chunk_scan_combined(
hidden_states_p.view(1, num_prefill_tokens,
varlen_state = mamba_chunk_scan_combined_varlen(
hidden_states_p.view(num_prefill_tokens,
self.num_heads // self.tp_size,
self.head_dim),
dt.unsqueeze(0),
dt,
self.A,
B.view(1, num_prefill_tokens, 1, -1),
C.view(1, num_prefill_tokens, 1, -1),
B.view(num_prefill_tokens, 1, -1),
C.view(num_prefill_tokens, 1, -1),
chunk_size=chunk_size,
D=self.D,
z=gate_p.view(1, num_prefill_tokens,
z=gate_p.view(num_prefill_tokens,
self.num_heads // self.tp_size, self.head_dim),
dt_bias=self.dt_bias,
seq_idx=seq_idx_p,
@ -374,11 +372,9 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
chunk_offsets=chunk_offsets_p,
cu_seqlens=query_start_loc_p,
initial_states=initial_states,
return_varlen_states=True,
return_final_states=False,
dt_softplus=True,
dt_limit=(0.0, float("inf")),
out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1,
out=preallocated_ssm_out_p.view(num_prefill_tokens, -1,
self.head_dim),
state_dtype=ssm_state.dtype,
)

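At the call sites the change is purely a layout one: the dummy leading batch dimension of 1 disappears, so the prefill views collapse from (1, num_prefill_tokens, ...) to (num_prefill_tokens, ...). A toy before/after of the reshapes, with hypothetical sizes:

import torch

num_prefill_tokens, num_heads, head_dim = 96, 8, 64
hidden_states_p = torch.rand(num_prefill_tokens, num_heads * head_dim)

old_view = hidden_states_p.view(1, num_prefill_tokens, num_heads, head_dim)  # pre-refactor
new_view = hidden_states_p.view(num_prefill_tokens, num_heads, head_dim)     # post-refactor
assert old_view.squeeze(0).shape == new_view.shape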
@ -115,7 +115,7 @@ class Mamba2AttentionMetadata:
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
query_start_loc: torch.Tensor
query_start_loc_p: torch.Tensor
seq_lens: torch.Tensor

prep_initial_states: bool
@ -151,7 +151,7 @@ class Mamba2AttentionMetadataBuilder(
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> Mamba2AttentionMetadata:
num_reqs = common_attn_metadata.num_reqs
query_start_loc = common_attn_metadata.query_start_loc
query_start_loc_p = None
seq_lens = common_attn_metadata.seq_lens

seq_idx_p = None
@ -179,7 +179,7 @@ class Mamba2AttentionMetadataBuilder(
num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0)
prep_initial_states = torch.any(has_initial_states_cpu).item()
has_initial_states_p = has_initial_states_cpu.to(
query_start_loc.device)
common_attn_metadata.query_start_loc.device)

query_start_loc_p = common_attn_metadata.query_start_loc[
-num_prefills - 1:] - num_decode_tokens
@ -190,7 +190,6 @@ class Mamba2AttentionMetadataBuilder(
device=query_start_loc_p.device),
query_start_loc_p.diff(),
output_size=num_prefill_tokens)
seq_idx_p.unsqueeze_(0)

# We compute metadata for chunked prefill once at the top level
# model forward and reuse them in mamba layers. If not needed,
@ -217,7 +216,7 @@ class Mamba2AttentionMetadataBuilder(
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
query_start_loc=query_start_loc,
query_start_loc_p=query_start_loc_p,
seq_lens=seq_lens,
prep_initial_states=prep_initial_states,
chunk_size=self.chunk_size,

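The metadata builder keeps seq_idx_p one-dimensional now (the unsqueeze is gone); it is presumably built along the lines of the repeat_interleave call partially visible in the hunk above. A standalone sketch with made-up request boundaries:

import torch

query_start_loc_p = torch.tensor([0, 5, 12, 20])     # hypothetical prefill boundaries
num_prefills = query_start_loc_p.numel() - 1
num_prefill_tokens = int(query_start_loc_p[-1])

seq_idx_p = torch.repeat_interleave(
    torch.arange(num_prefills, dtype=torch.int32),
    query_start_loc_p.diff(),
    output_size=num_prefill_tokens)                   # shape: (num_prefill_tokens,)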