[Bugfix] Mamba cache Cuda Graph padding (#6214)
This commit is contained in:
parent 185ad31f37
commit ddc369fba1
@@ -1,5 +1,7 @@
 import pytest
 
+from vllm.worker.model_runner import _get_graph_batch_size
+
 MODELS = ["ai21labs/Jamba-tiny-random"]
@@ -32,6 +34,32 @@ def test_models(
             f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
 
 
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("max_tokens", [20])
+def test_mamba_cache_cg_padding(
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+) -> None:
+    # This test is for verifying that mamba cache is padded to CG captured
+    # batch size. If it's not, a torch RuntimeError will be raised because
+    # tensor dimensions aren't compatible
+    while len(example_prompts) == _get_graph_batch_size(len(example_prompts)):
+        example_prompts.append(example_prompts[0])
+
+    try:
+        with vllm_runner(model, dtype=dtype) as vllm_model:
+            vllm_model.generate_greedy(example_prompts, max_tokens)
+    except RuntimeError:
+        pytest.fail(
+            "Couldn't run batch size which is not equal to a Cuda Graph "
+            "captured batch size. "
+            "Could be related to mamba cache not padded correctly")
+
+
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["float"])
 def test_state_cleanup(
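The new test leans on _get_graph_batch_size returning the batch size the CUDA graph runner would actually capture for a given number of sequences; prompts are appended until the prompt count no longer equals its own rounded size, which forces the run to pad. Below is a minimal sketch of that rounding, assuming graphs are captured for batch sizes 1, 2, 4 and then multiples of 8 (the thresholds are an assumption, not taken from this diff):

# Sketch of the rounding that _get_graph_batch_size performs, so the test's
# while-loop is easy to follow. Assumption: CUDA graphs are captured for
# batch sizes 1, 2, 4 and then multiples of 8.
_BATCH_SIZE_ALIGNMENT = 8


def graph_batch_size(batch_size: int) -> int:
    # Round a runtime batch size up to the nearest captured graph size.
    if batch_size <= 2:
        return batch_size
    if batch_size <= 4:
        return 4
    return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
            _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)


if __name__ == "__main__":
    # 5 prompts get padded to an 8-slot graph, so the mamba cache must also
    # be prepared for 8 sequences; that mismatch is what the test provokes.
    for n in (1, 2, 3, 5, 9):
        print(n, "->", graph_batch_size(n))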
@@ -788,12 +788,12 @@ class JambaForCausalLM(nn.Module):
             key in kwargs
             for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
         request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
-        batch_size = len(request_ids_to_seq_ids)
+        cg_batch_size = input_buffers['input_ids'].shape[0]
         (
             current_mamba_cache,
             indices,
         ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
-                                                  batch_size)
+                                                  cg_batch_size)
         self.current_indices = indices
         finished_requests_ids = kwargs["finished_requests_ids"]
         self._release_mamba_cache(finished_requests_ids)
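The model-side fix sizes the per-run mamba cache from the captured graph's static input_ids buffer (cg_batch_size) rather than from the number of live requests, so the cache tensors line up with the buffers the replayed graph expects. Below is a minimal sketch of that padding idea, not the Jamba implementation; prepare_run_cache, the state shape, and the slot layout are all hypothetical:

# Minimal sketch of padding a per-request state cache up to the CUDA-graph
# captured batch size. Hypothetical names and shapes; not vLLM's Jamba code.
import torch

CONV_STATE_SHAPE = (4, 16)  # assumed per-sequence state shape


def prepare_run_cache(per_seq_state: dict, cg_batch_size: int) -> torch.Tensor:
    # The replayed graph was captured with tensors whose batch dimension is
    # cg_batch_size, so the gathered cache must use that size. Sizing it with
    # len(per_seq_state) (the real batch size) would fail with a shape
    # mismatch against the graph's static buffers -- the bug fixed here.
    buf = torch.zeros(cg_batch_size, *CONV_STATE_SHAPE)
    for slot, state in enumerate(per_seq_state.values()):
        buf[slot].copy_(state)  # live sequences fill the leading slots
    return buf                  # trailing slots remain zero padding


if __name__ == "__main__":
    states = {seq_id: torch.randn(*CONV_STATE_SHAPE) for seq_id in range(5)}
    padded = prepare_run_cache(states, cg_batch_size=8)
    print(padded.shape)  # torch.Size([8, 4, 16]): 5 live rows, 3 padded rows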