mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 14:35:27 +08:00
[Kernel] Change interface to Mamba selective_state_update for continuous batching (#8039)
This commit is contained in:
parent
b3195bc9e4
commit
db9120cded
@ -323,3 +323,149 @@ def test_selective_state_update(dim, dstate, has_z, itype):
|
||||
|
||||
assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype",
|
||||
[torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("has_z", [False, True])
|
||||
@pytest.mark.parametrize("dstate", [16, 32, 64])
|
||||
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
|
||||
def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
|
||||
device = "cuda"
|
||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 7e-2, 7e-2
|
||||
if torch.version.hip:
|
||||
atol *= 2
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
batch_size = 16
|
||||
|
||||
total_entries = 10 * batch_size
|
||||
state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device)
|
||||
state_indices = torch.randperm(total_entries)[:batch_size].to(
|
||||
dtype=torch.int32, device=device)
|
||||
|
||||
x = torch.randn(batch_size, dim, device=device, dtype=itype)
|
||||
dt = torch.randn(batch_size, dim, device=device, dtype=itype)
|
||||
dt_bias = torch.rand(dim, device=device) - 4.0
|
||||
A = -torch.rand(dim, dstate, device=device) - 1.0
|
||||
B = torch.randn(batch_size, dstate, device=device)
|
||||
C = torch.randn(batch_size, dstate, device=device)
|
||||
D = torch.randn(dim, device=device)
|
||||
z = torch.randn_like(x) if has_z else None
|
||||
state_ref = state[state_indices, :].detach().clone()
|
||||
out = selective_state_update(state,
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=D,
|
||||
z=z,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True,
|
||||
state_batch_indices=state_indices)
|
||||
out_ref = selective_state_update_ref(state_ref,
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=D,
|
||||
z=z,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True)
|
||||
|
||||
assert torch.allclose(state[state_indices, :],
|
||||
state_ref,
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype",
|
||||
[torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("has_z", [False, True])
|
||||
@pytest.mark.parametrize("tie_hdim", [False, True])
|
||||
@pytest.mark.parametrize("ngroups", [1, 2, 4])
|
||||
@pytest.mark.parametrize("dstate", [16, 32, 64])
|
||||
@pytest.mark.parametrize("dim", [2048, 4096])
|
||||
def test_selective_state_update_with_heads_with_batch_indices(
|
||||
dim, dstate, ngroups, has_z, tie_hdim, itype):
|
||||
device = "cuda"
|
||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 3e-2)
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 1e-1, 1e-1
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
batch_size = 16
|
||||
headdim = 64
|
||||
nheads = dim // headdim
|
||||
|
||||
total_entries = 10 * batch_size
|
||||
state = torch.randn(total_entries,
|
||||
nheads,
|
||||
headdim,
|
||||
dstate,
|
||||
dtype=itype,
|
||||
device=device)
|
||||
state_indices = torch.randperm(total_entries)[:batch_size].to(
|
||||
dtype=torch.int32, device=device)
|
||||
|
||||
x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
|
||||
if not tie_hdim:
|
||||
dt = torch.randn(batch_size,
|
||||
nheads,
|
||||
headdim,
|
||||
device=device,
|
||||
dtype=itype)
|
||||
dt_bias = torch.rand(nheads, headdim, device=device) - 4.0
|
||||
A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0
|
||||
D = torch.randn(nheads, headdim, device=device)
|
||||
else:
|
||||
dt = repeat(torch.randn(batch_size, nheads, device=device,
|
||||
dtype=itype),
|
||||
"b h -> b h p",
|
||||
p=headdim)
|
||||
dt_bias = repeat(torch.rand(nheads, device=device) - 4.0,
|
||||
"h -> h p",
|
||||
p=headdim)
|
||||
A = repeat(-torch.rand(nheads, device=device) - 1.0,
|
||||
"h -> h p n",
|
||||
p=headdim,
|
||||
n=dstate)
|
||||
D = repeat(torch.randn(nheads, device=device), "h -> h p", p=headdim)
|
||||
B = torch.randn(batch_size, ngroups, dstate, device=device)
|
||||
C = torch.randn(batch_size, ngroups, dstate, device=device)
|
||||
z = torch.randn_like(x) if has_z else None
|
||||
state_ref = state[state_indices, :].detach().clone()
|
||||
out = selective_state_update(state,
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=D,
|
||||
z=z,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True,
|
||||
state_batch_indices=state_indices)
|
||||
out_ref = selective_state_update_ref(state_ref,
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=D,
|
||||
z=z,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True)
|
||||
|
||||
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
||||
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
||||
assert torch.allclose(state[state_indices, :],
|
||||
state_ref,
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/selective_state_update.py
|
||||
|
||||
import torch
|
||||
import triton
|
||||
@ -27,6 +28,10 @@ else:
|
||||
{"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
|
||||
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
|
||||
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
|
||||
@triton.heuristics({
|
||||
"HAS_STATE_BATCH_INDICES":
|
||||
lambda args: args["state_batch_indices_ptr"] is not None
|
||||
})
|
||||
@triton.heuristics(
|
||||
{"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
|
||||
@triton.jit
|
||||
@ -42,6 +47,7 @@ def _selective_scan_update_kernel(
|
||||
D_ptr,
|
||||
z_ptr,
|
||||
out_ptr,
|
||||
state_batch_indices_ptr,
|
||||
# Matrix dimensions
|
||||
batch,
|
||||
nheads,
|
||||
@ -85,12 +91,24 @@ def _selective_scan_update_kernel(
|
||||
HAS_DT_BIAS: tl.constexpr,
|
||||
HAS_D: tl.constexpr,
|
||||
HAS_Z: tl.constexpr,
|
||||
HAS_STATE_BATCH_INDICES: tl.constexpr,
|
||||
BLOCK_SIZE_DSTATE: tl.constexpr,
|
||||
):
|
||||
pid_m = tl.program_id(axis=0)
|
||||
pid_b = tl.program_id(axis=1)
|
||||
pid_h = tl.program_id(axis=2)
|
||||
state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
|
||||
|
||||
# If HAS_STATE_BATCH_INDICES is true, then the ssm state's batch coordinate
|
||||
# is taken from the state_batch_indices_ptr Otherwise, the state coordinate
|
||||
# is the same as the batch id.
|
||||
if HAS_STATE_BATCH_INDICES:
|
||||
state_batch_indices_ptr += pid_b
|
||||
state_batch_idx = tl.load(state_batch_indices_ptr)
|
||||
state_ptr += (state_batch_idx * stride_state_batch +
|
||||
pid_h * stride_state_head)
|
||||
else:
|
||||
state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
|
||||
|
||||
x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
|
||||
dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
|
||||
if HAS_DT_BIAS:
|
||||
@ -177,7 +195,8 @@ def selective_state_update(state,
|
||||
D=None,
|
||||
z=None,
|
||||
dt_bias=None,
|
||||
dt_softplus=False):
|
||||
dt_softplus=False,
|
||||
state_batch_indices=None):
|
||||
"""
|
||||
Argument:
|
||||
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
|
||||
@ -211,7 +230,10 @@ def selective_state_update(state,
|
||||
z = z.unsqueeze(1)
|
||||
if dt_bias is not None and dt_bias.dim() == 1:
|
||||
dt_bias = dt_bias.unsqueeze(0)
|
||||
batch, nheads, dim, dstate = state.shape
|
||||
|
||||
_, nheads, dim, dstate = state.shape
|
||||
batch = x.shape[0]
|
||||
|
||||
assert x.shape == (batch, nheads, dim)
|
||||
assert dt.shape == x.shape
|
||||
assert A.shape == (nheads, dim, dstate)
|
||||
@ -225,6 +247,8 @@ def selective_state_update(state,
|
||||
assert z.shape == x.shape
|
||||
if dt_bias is not None:
|
||||
assert dt_bias.shape == (nheads, dim)
|
||||
if state_batch_indices is not None:
|
||||
assert state_batch_indices.shape == (batch, )
|
||||
out = torch.empty_like(x)
|
||||
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
|
||||
z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else
|
||||
@ -249,6 +273,7 @@ def selective_state_update(state,
|
||||
D,
|
||||
z,
|
||||
out,
|
||||
state_batch_indices,
|
||||
batch,
|
||||
nheads,
|
||||
dim,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user