[Bugfix][Mamba] - Fix Conv State Kernel FP32 Support (#24883)
Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
This commit is contained in:
parent 3ed1ec4af2
commit 66072b36db
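As far as the hunks below show, the change has two parts: the FP32-state test is renamed to test_fp32_cache_state and parametrized over both cache-dtype knobs (mamba_ssm_cache_dtype and mamba_cache_dtype), and the causal_conv1d_fn / causal_conv1d_update ops are changed to cast the input x to the conv state's dtype before the kernel runs and to cast the output back to x's original dtype on return. Minimal sketches after the relevant hunks below illustrate both patterns.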
@@ -418,7 +418,9 @@ def test_full_cuda_graph(
 @pytest.mark.parametrize("model", FP32_STATE_MODELS)
 @pytest.mark.parametrize("max_tokens", [64])
 @pytest.mark.parametrize("num_logprobs", [5])
-def test_fp32_state(
+@pytest.mark.parametrize("cache_dtype_param",
+                         ["mamba_ssm_cache_dtype", "mamba_cache_dtype"])
+def test_fp32_cache_state(
     hf_runner,
     vllm_runner,
     example_prompts,
@@ -426,6 +428,7 @@ def test_fp32_state(
     model: str,
     max_tokens: int,
     num_logprobs: int,
+    cache_dtype_param: str,
 ) -> None:
 
     try:
@@ -443,13 +446,13 @@ def test_fp32_state(
         m.setenv("VLLM_USE_V1", "0")
         with vllm_runner(model,
                          max_num_seqs=MAX_NUM_SEQS,
-                         mamba_ssm_cache_dtype="float32") as vllm_model:
+                         **{cache_dtype_param: "float32"}) as vllm_model:
             vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
                 example_prompts, max_tokens, num_logprobs)
 
     with vllm_runner(model,
                      max_num_seqs=MAX_NUM_SEQS,
-                     mamba_ssm_cache_dtype="float32") as vllm_model:
+                     **{cache_dtype_param: "float32"}) as vllm_model:
         vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
             example_prompts, max_tokens, num_logprobs)
 
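The **{cache_dtype_param: "float32"} pattern lets a single test body exercise both cache-dtype keyword arguments. A minimal, self-contained sketch of how that expansion behaves (the run function and its parameters here are hypothetical stand-ins, not vLLM APIs):

import pytest

def run(*, mamba_ssm_cache_dtype="auto", mamba_cache_dtype="auto"):
    # Stand-in for vllm_runner: just report which knob was set.
    return {"mamba_ssm_cache_dtype": mamba_ssm_cache_dtype,
            "mamba_cache_dtype": mamba_cache_dtype}

@pytest.mark.parametrize("cache_dtype_param",
                         ["mamba_ssm_cache_dtype", "mamba_cache_dtype"])
def test_unpacking(cache_dtype_param):
    # **{name: value} expands to a single keyword argument named `name`,
    # so one parametrized test covers both cache-dtype options.
    cfg = run(**{cache_dtype_param: "float32"})
    assert cfg[cache_dtype_param] == "float32"

The remaining hunks touch the causal_conv1d ops (causal_conv1d_fn and causal_conv1d_update).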
@@ -415,6 +415,9 @@ def causal_conv1d_fn(
         activation = "silu"
 
     args = None
+    # Store original dtype to cast back at the end
+    original_x_dtype = x.dtype
+    x = x.to(conv_states.dtype)
     out = torch.empty_like(x)
     if metadata is not None:
         cu_seqlen = metadata.cu_seqlen
@@ -613,7 +616,7 @@ def causal_conv1d_fn(
         BLOCK_N=256,
         num_stages=2,
     )
-    return out
+    return out.to(original_x_dtype)
 
 
 @triton.jit()
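The kernel-side fix is a wrap pattern: remember x's dtype, cast x up to the conv state's dtype (e.g. float32) so the kernel and the output buffer created with torch.empty_like(x) run in the state's precision, then cast the result back on return. A minimal sketch of that pattern with a placeholder compute step (conv_like_op and its silu body are illustrative only, not the real Triton kernel):

import torch

def conv_like_op(x: torch.Tensor, conv_states: torch.Tensor) -> torch.Tensor:
    # Store original dtype to cast back at the end (same idea as the diff).
    original_x_dtype = x.dtype
    x = x.to(conv_states.dtype)

    # Placeholder for the real kernel launch: the output buffer inherits
    # x's dtype, which is why x is cast up front.
    out = torch.empty_like(x)
    out.copy_(torch.nn.functional.silu(x))

    return out.to(original_x_dtype)

# With a bf16 input and an fp32 conv state, the op runs in fp32 internally
# but still returns a bf16 tensor to the caller.
x = torch.randn(2, 8, dtype=torch.bfloat16)
conv_states = torch.zeros(2, 8, 4, dtype=torch.float32)
y = conv_like_op(x, conv_states)
assert y.dtype == torch.bfloat16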
@@ -973,6 +976,9 @@ def causal_conv1d_update(
         activation = "silu" if activation is True else None
     elif activation is not None:
         assert activation in ["silu", "swish"]
+
+    original_x_dtype = x.dtype
+    x = x.to(conv_state.dtype)
     unsqueeze = query_start_loc is None and x.dim() == 2
     if unsqueeze:
         # make it (batch, dim, seqlen) with seqlen == 1
@@ -1081,4 +1087,4 @@ def causal_conv1d_update(
     )
     if unsqueeze:
         out = out.squeeze(-1)
-    return out
+    return out.to(original_x_dtype)
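causal_conv1d_update, the decode-path counterpart, gets the same treatment: the single-token update runs in the conv state's dtype and hands back a tensor in the caller's dtype. A toy, self-contained illustration of that behaviour (naive_conv1d_update is a made-up reference, not the vLLM kernel):

import torch

def naive_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor,
                        weight: torch.Tensor) -> torch.Tensor:
    # Toy single-token depthwise causal-conv update, illustration only.
    #   x:          (batch, dim)        new token activations
    #   conv_state: (batch, dim, width) rolling window kept in cache dtype
    #   weight:     (dim, width)        depthwise filter
    original_x_dtype = x.dtype
    x = x.to(conv_state.dtype)

    # Shift the rolling window left and append the new token.
    conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))
    conv_state[:, :, -1] = x

    # Depthwise convolution over the window, done in the state's dtype.
    out = (conv_state * weight).sum(dim=-1)
    return out.to(original_x_dtype)

x = torch.randn(2, 4, dtype=torch.bfloat16)
conv_state = torch.zeros(2, 4, 3, dtype=torch.float32)   # fp32 cache
weight = torch.randn(4, 3, dtype=torch.float32)
y = naive_conv1d_update(x, conv_state, weight)
assert y.dtype == torch.bfloat16 and conv_state.dtype == torch.float32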