mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 23:06:10 +08:00
[Bugfix][Mamba] - Fix Conv State Kernel FP32 Support (#24883)
Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
This commit is contained in:
parent
3ed1ec4af2
commit
66072b36db
@ -418,7 +418,9 @@ def test_full_cuda_graph(
|
||||
@pytest.mark.parametrize("model", FP32_STATE_MODELS)
|
||||
@pytest.mark.parametrize("max_tokens", [64])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
def test_fp32_state(
|
||||
@pytest.mark.parametrize("cache_dtype_param",
|
||||
["mamba_ssm_cache_dtype", "mamba_cache_dtype"])
|
||||
def test_fp32_cache_state(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
@ -426,6 +428,7 @@ def test_fp32_state(
|
||||
model: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
cache_dtype_param: str,
|
||||
) -> None:
|
||||
|
||||
try:
|
||||
@ -443,13 +446,13 @@ def test_fp32_state(
|
||||
m.setenv("VLLM_USE_V1", "0")
|
||||
with vllm_runner(model,
|
||||
max_num_seqs=MAX_NUM_SEQS,
|
||||
mamba_ssm_cache_dtype="float32") as vllm_model:
|
||||
**{cache_dtype_param: "float32"}) as vllm_model:
|
||||
vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
|
||||
with vllm_runner(model,
|
||||
max_num_seqs=MAX_NUM_SEQS,
|
||||
mamba_ssm_cache_dtype="float32") as vllm_model:
|
||||
**{cache_dtype_param: "float32"}) as vllm_model:
|
||||
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
|
||||
|
||||
@ -415,6 +415,9 @@ def causal_conv1d_fn(
|
||||
activation = "silu"
|
||||
|
||||
args = None
|
||||
# Store original dtype to cast back at the end
|
||||
original_x_dtype = x.dtype
|
||||
x = x.to(conv_states.dtype)
|
||||
out = torch.empty_like(x)
|
||||
if metadata is not None:
|
||||
cu_seqlen = metadata.cu_seqlen
|
||||
@ -613,7 +616,7 @@ def causal_conv1d_fn(
|
||||
BLOCK_N=256,
|
||||
num_stages=2,
|
||||
)
|
||||
return out
|
||||
return out.to(original_x_dtype)
|
||||
|
||||
|
||||
@triton.jit()
|
||||
@ -973,6 +976,9 @@ def causal_conv1d_update(
|
||||
activation = "silu" if activation is True else None
|
||||
elif activation is not None:
|
||||
assert activation in ["silu", "swish"]
|
||||
|
||||
original_x_dtype = x.dtype
|
||||
x = x.to(conv_state.dtype)
|
||||
unsqueeze = query_start_loc is None and x.dim() == 2
|
||||
if unsqueeze:
|
||||
# make it (batch, dim, seqlen) with seqlen == 1
|
||||
@ -1081,4 +1087,4 @@ def causal_conv1d_update(
|
||||
)
|
||||
if unsqueeze:
|
||||
out = out.squeeze(-1)
|
||||
return out
|
||||
return out.to(original_x_dtype)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user