[Bugfix] Fix hybrid model tests (#17182)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Authored by Cyrus Leung on 2025-04-26 06:14:37 +08:00; committed by GitHub
parent 48cb2109b6
commit 43faa0461a
3 changed files with 159 additions and 535 deletions

File 1 of 3

@@ -531,7 +531,10 @@ class HfRunner:
         for _, hidden_state in enumerate(hidden_states):
             last_hidden_states = hidden_state[-1][0]
             logits = torch.matmul(
-                last_hidden_states.to(output_embeddings.weight.device),
+                last_hidden_states.to(
+                    device=output_embeddings.weight.device,
+                    dtype=output_embeddings.weight.dtype,
+                ),
                 output_embeddings.weight.t(),
             )
             if getattr(output_embeddings, "bias", None) is not None:
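
Note on the cast: torch.matmul requires both operands to share a dtype, so casting only the device (as the old code did) breaks whenever the hidden states and the output embeddings are held in different precisions. A minimal standalone repro of the failure and of the combined device/dtype cast (the shapes here are made up for illustration):

import torch

weight = torch.randn(16, 8)                      # float32 lm-head weight
hidden = torch.randn(4, 8, dtype=torch.float16)  # half-precision hidden state

try:
    torch.matmul(hidden, weight.t())  # mixed dtypes
except RuntimeError as err:
    print(f"mixed-dtype matmul fails: {err}")

# Casting device and dtype together, as the new code does, is well-defined:
logits = torch.matmul(hidden.to(device=weight.device, dtype=weight.dtype),
                      weight.t())
print(logits.shape)  # torch.Size([4, 16])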

File 2 of 3

@@ -6,71 +6,84 @@ from tests.utils import multi_gpu_test
 from vllm.engine.arg_utils import EngineArgs
 from vllm.sampling_params import SamplingParams
 
-from ...utils import check_outputs_equal
+from ...utils import check_logprobs_close, check_outputs_equal
 
-# This test is for the hybrid models
-MODELS = [
-    "ai21labs/Jamba-tiny-dev", "Zyphra/Zamba2-1.2B-instruct",
-    "pfnet/plamo-2-1b"
+# NOTE: The first model in each list is taken as the primary model,
+# meaning that it will be used in all tests in this file
+# The rest of the models will only be tested by test_models
+
+SSM_MODELS = [
+    "state-spaces/mamba-130m-hf",
+    "tiiuae/falcon-mamba-tiny-dev",
+    # TODO: Compare to a Mamba2 model. The HF transformers implementation of
+    # Mamba2 is buggy for Codestral as it doesn't handle n_groups.
+    # See https://github.com/huggingface/transformers/pull/35943
+    # "mistralai/Mamba-Codestral-7B-v0.1",
 ]
-# Bamba at Fp32 is too big for the CI (L4 GPU).
-# MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-ai-platform/Bamba-9B"]
-# Note: Running Plamo2 in transformers implementation requires to install
-# causal-conv1d package, which is not listed as a test dependency as it's
-# not compatible with pip-compile.
+
+HYBRID_MODELS = [
+    "ai21labs/Jamba-tiny-dev",
+    # NOTE: Running Plamo2 in transformers implementation requires to install
+    # causal-conv1d package, which is not listed as a test dependency as it's
+    # not compatible with pip-compile.
+    "pfnet/plamo-2-1b",
+    "Zyphra/Zamba2-1.2B-instruct",
+    "ibm-ai-platform/Bamba-9B",
+]
+
+# Avoid OOM
+MAX_NUM_SEQS = 4
 
 
-@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["float"])
-@pytest.mark.parametrize("max_tokens", [96])
+@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
+@pytest.mark.parametrize("max_tokens", [64])
+@pytest.mark.parametrize("num_logprobs", [5])
 def test_models(
     hf_runner,
     vllm_runner,
     example_prompts,
     model: str,
-    dtype: str,
     max_tokens: int,
+    num_logprobs: int,
 ) -> None:
-    # numeric error produces different generation
-    if "Bamba" in model:
-        example_prompts.pop(3)
-
-    with hf_runner(model, dtype=dtype) as hf_model:
-        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+    with hf_runner(model) as hf_model:
+        hf_outputs = hf_model.generate_greedy_logprobs_limit(
+            example_prompts, max_tokens, num_logprobs)
 
-    with vllm_runner(model, dtype=dtype) as vllm_model:
-        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
+    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+        vllm_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
 
-    for i in range(len(example_prompts)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_outputs[i]
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_logprobs_close(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
 
 
-@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["float"])
-@pytest.mark.parametrize("max_tokens", [96])
+@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
+@pytest.mark.parametrize("max_tokens", [64])
+@pytest.mark.parametrize("num_logprobs", [5])
 def test_batching(
     vllm_runner,
     example_prompts,
     model: str,
-    dtype: str,
     max_tokens: int,
+    num_logprobs: int,
 ) -> None:
-    # To pass the small model tests, we need full precision.
     for_loop_outputs = []
-    with vllm_runner(model, dtype=dtype) as vllm_model:
+    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
         for prompt in example_prompts:
-            for_loop_outputs.append(
-                vllm_model.generate_greedy([prompt], max_tokens)[0])
+            single_output, = vllm_model.generate_greedy_logprobs([prompt],
+                                                                 max_tokens,
+                                                                 num_logprobs)
+            for_loop_outputs.append(single_output)
 
-        batched_outputs = vllm_model.generate_greedy(example_prompts,
-                                                     max_tokens)
+        batched_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
 
-    check_outputs_equal(
+    check_logprobs_close(
        outputs_0_lst=for_loop_outputs,
        outputs_1_lst=batched_outputs,
        name_0="for_loop_vllm",
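
Note on the comparison change: check_outputs_equal demanded exact token-ID equality, which is brittle for SSM and hybrid models where small numeric differences can flip near-ties. check_logprobs_close tolerates that divergence. A minimal sketch of the criterion it relies on (illustrative only, not vLLM's actual implementation): two runs match at a position if they picked the same token, or if each run's pick still appears in the other run's top-k candidate set.

def tokens_close(token_0: int, token_1: int,
                 logprobs_0: dict[int, float],
                 logprobs_1: dict[int, float]) -> bool:
    """Accept exact agreement, or divergence within each other's top-k."""
    if token_0 == token_1:
        return True
    return token_0 in logprobs_1 and token_1 in logprobs_0

# Runs pick different tokens (7 vs 9), but each is in the other's top-5:
assert tokens_close(7, 9, {7: -0.9, 9: -1.0}, {9: -0.9, 7: -1.1})
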
@@ -78,72 +91,35 @@ def test_batching(
     )
 
 
-@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["float16"])
-@pytest.mark.parametrize("max_tokens", [10])
-def test_mamba_prefill_chunking_with_parallel_sampling(
-        hf_runner, vllm_runner, example_prompts, model: str, dtype: str,
-        max_tokens: int) -> None:
-    # Tests prefill chunking in conjunction with n>1, in this case,
-    # prefill is populated with decoding tokens and we test that it
-    # doesn't fail This test might fail if cache is not allocated
-    # correctly for n > 1 decoding steps inside a
-    # chunked prefill forward pass (where we have both prefills
-    # and decoding together )
-
-    if 'plamo-2' in model:
-        dtype = "float"  # use a different dtype for plamo
-
-    sampling_params = SamplingParams(n=3,
-                                     temperature=1,
-                                     seed=0,
-                                     max_tokens=max_tokens)
-    with vllm_runner(
-            model,
-            dtype=dtype,
-            enable_chunked_prefill=True,
-            max_num_batched_tokens=30,
-            max_num_seqs=10  # forces prefill chunks with decoding
-    ) as vllm_model:
-        vllm_model.generate(example_prompts, sampling_params)
-
-
-@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["bfloat16"])
-@pytest.mark.parametrize("max_tokens", [7])
-def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts,
-                                model: str, dtype: str,
-                                max_tokens: int) -> None:
-    # numeric error during prefill chunking produces different generation
-    # compared to w/o prefill chunking for those examples, removed them for now
-    if "Jamba" in model:
-        example_prompts.pop(7)
-        example_prompts.pop(2)
-        example_prompts.pop(1)
-    elif "Bamba" in model:
-        example_prompts.pop(6)
-        example_prompts.pop(3)
-        example_prompts.pop(2)
-        dtype = "half"  # use a different dtype for Bamba
-    elif "Zamba2" in model:
-        example_prompts.pop(7)
-        dtype = "half"
-    elif "plamo-2-1b" in model:
-        example_prompts.pop(7)
-
-    with hf_runner(model, dtype=dtype) as hf_model:
-        non_chunked = hf_model.generate_greedy(example_prompts, max_tokens)
+@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
+@pytest.mark.parametrize("max_tokens", [32])
+@pytest.mark.parametrize("num_logprobs", [5])
+@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
+def test_chunked_prefill(
+    vllm_runner,
+    example_prompts,
+    model: str,
+    max_tokens: int,
+    num_logprobs: int,
+    chunked_prefill_token_size: int,
+) -> None:
+    max_num_seqs = chunked_prefill_token_size
+    max_num_batched_tokens = chunked_prefill_token_size
 
     with vllm_runner(model,
-                     dtype=dtype,
                      enable_chunked_prefill=True,
-                     max_num_batched_tokens=5,
-                     max_num_seqs=2) as vllm_model:
-        chunked = vllm_model.generate_greedy(example_prompts,
-                                             max_tokens=max_tokens)
+                     max_num_batched_tokens=max_num_batched_tokens,
+                     max_num_seqs=max_num_seqs) as vllm_model:
+        chunked = vllm_model.generate_greedy_logprobs(example_prompts,
+                                                      max_tokens, num_logprobs)
 
-    check_outputs_equal(
+    with vllm_runner(model,
+                     enable_chunked_prefill=False,
+                     max_num_seqs=max_num_seqs) as vllm_model:
+        non_chunked = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
+
+    check_logprobs_close(
        outputs_0_lst=chunked,
        outputs_1_lst=non_chunked,
        name_0="chunked",
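
Note on the parametrization: tying both max_num_seqs and max_num_batched_tokens to chunked_prefill_token_size forces every prompt longer than that budget to be prefilled across several scheduler steps. Roughly, as a sketch of the budgeting (not the scheduler's actual code):

def prefill_chunks(prompt_len: int, max_num_batched_tokens: int) -> list[int]:
    """Tokens prefilled per scheduler step for a single prompt."""
    chunks = []
    while prompt_len > 0:
        step = min(prompt_len, max_num_batched_tokens)
        chunks.append(step)
        prompt_len -= step
    return chunks

assert prefill_chunks(30, 16) == [16, 14]  # two chunked steps
assert prefill_chunks(30, 1) == [1] * 30   # worst case in this test
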
@@ -151,57 +127,51 @@ def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts,
     )
 
 
-@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["float"])
-@pytest.mark.parametrize("max_tokens", [15])
-def test_parallel_sampling(
+@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
+@pytest.mark.parametrize("max_tokens", [10])
+def test_chunked_prefill_with_parallel_sampling(
     vllm_runner,
     example_prompts,
     model: str,
-    dtype: str,
     max_tokens: int,
 ) -> None:
+    """
+    Tests chunked prefill in conjunction with n > 1.
+    In this case, prefill is populated with decoding tokens and
+    we test that it doesn't fail.
 
-    with vllm_runner(model, dtype=dtype) as vllm_model:
-        for_loop_outputs = []
-        for _ in range(10):
-            for_loop_outputs.append(
-                # using example_prompts index 1 instead of 0 since with 0 the
-                # logprobs get really close and the test doesn't pass
-                vllm_model.generate_greedy([example_prompts[1]], max_tokens)
-                [0])
-        sampling_params = SamplingParams(n=10,
-                                         temperature=0.001,
-                                         seed=0,
-                                         max_tokens=max_tokens)
-        n_lt_1_outputs = vllm_model.generate([example_prompts[1]],
-                                             sampling_params)
-    token_ids, texts = n_lt_1_outputs[0]
-    n_lt_1_outputs = [(token_id, text)
-                      for token_id, text in zip(token_ids, texts)]
-
-    check_outputs_equal(
-        outputs_0_lst=n_lt_1_outputs,
-        outputs_1_lst=for_loop_outputs,
-        name_0="vllm_n_lt_1_outputs",
-        name_1="vllm",
-    )
+    This test might fail if cache is not allocated correctly for n > 1
+    decoding steps inside a chunked prefill forward pass
+    (where we have both prefill and decode together)
+    """
+    sampling_params = SamplingParams(n=3,
+                                     temperature=1,
+                                     seed=0,
+                                     max_tokens=max_tokens)
+    with vllm_runner(
+            model,
+            enable_chunked_prefill=True,
+            # forces prefill chunks with decoding
+            max_num_batched_tokens=MAX_NUM_SEQS * 3,
+            max_num_seqs=MAX_NUM_SEQS,
+    ) as vllm_model:
+        vllm_model.generate(example_prompts, sampling_params)
 
 
-@pytest.mark.skip(reason="RE-ENABLE: test is currently failing on main.")
-@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
 @pytest.mark.parametrize("max_tokens", [20])
 def test_mamba_cache_cg_padding(
     vllm_runner,
     example_prompts,
     model: str,
-    dtype: str,
     max_tokens: int,
 ) -> None:
-    # This test is for verifying that mamba cache is padded to CG captured
-    # batch size. If it's not, a torch RuntimeError will be raised because
-    # tensor dimensions aren't compatible
+    """
+    This test is for verifying that mamba cache is padded to CG captured
+    batch size. If it's not, a torch RuntimeError will be raised because
+    tensor dimensions aren't compatible.
+    """
     vllm_config = EngineArgs(model=model,
                              trust_remote_code=True).create_engine_config()
     while len(example_prompts) == vllm_config.pad_for_cudagraph(
@@ -209,7 +179,7 @@ def test_mamba_cache_cg_padding(
         example_prompts.append(example_prompts[0])
 
     try:
-        with vllm_runner(model, dtype=dtype) as vllm_model:
+        with vllm_runner(model) as vllm_model:
             vllm_model.generate_greedy(example_prompts, max_tokens)
     except RuntimeError:
         pytest.fail(
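
Note on the padding check: the while loop in this test grows the batch until its size is not itself a CUDA-graph capture size, so the runner is forced to pad. A toy version of the padding rule (the capture list here is hypothetical; the real one comes from the engine config):

CAPTURED_SIZES = [1, 2, 4, 8, 16, 32]  # hypothetical capture list

def pad_for_cudagraph(batch_size: int) -> int:
    """Round a batch size up to the nearest captured size."""
    for size in CAPTURED_SIZES:
        if batch_size <= size:
            return size
    return batch_size  # beyond the largest capture, run eagerly

assert pad_for_cudagraph(3) == 4  # 3 must be padded, as the test requires
assert pad_for_cudagraph(8) == 8  # already a captured size
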
@@ -218,28 +188,24 @@ def test_mamba_cache_cg_padding(
             "Could be related to mamba cache not padded correctly")
 
 
-@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
 @pytest.mark.parametrize("max_tokens", [20])
 def test_models_preemption_recompute(
-    hf_runner,
     vllm_runner,
     example_prompts,
     model: str,
-    dtype: str,
     max_tokens: int,
 ) -> None:
-    # Tests that outputs are identical with and w/o preemtions (recompute)
-    assert dtype == "float"
-
-    with vllm_runner(model, dtype=dtype) as vllm_model:
-        vllm_model.model.llm_engine.scheduler[
-            0].ENABLE_ARTIFICIAL_PREEMPT = True
+    """
+    Tests that outputs are identical with and w/o preemptions (recompute).
+    """
+    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+        scheduler = vllm_model.model.llm_engine.scheduler[0]
+        scheduler.ENABLE_ARTIFICIAL_PREEMPT = True
         preempt_vllm_outputs = vllm_model.generate_greedy(
             example_prompts, max_tokens)
 
-        vllm_model.model.llm_engine.scheduler[
-            0].ENABLE_ARTIFICIAL_PREEMPT = False
+        scheduler.ENABLE_ARTIFICIAL_PREEMPT = False
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
 
     check_outputs_equal(
@@ -250,40 +216,43 @@ def test_models_preemption_recompute(
     )
 
 
-@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
 def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks(
     vllm_runner,
-    model: str,
-    dtype: str,
     example_prompts,
+    model: str,
 ) -> None:
-    # This test is for verifying that the hybrid inner state management doesn't
-    # collapse in case where the number of incoming requests and
-    # finished_requests_ids is larger than the maximum mamba block capacity.
-    # This could generally happen due to the fact that hybrid does support
-    # statelessness mechanism where it can cleanup new incoming requests in
-    # a single step.
+    """
+    This test is for verifying that the hybrid inner state management doesn't
+    collapse in case where the number of incoming requests and
+    finished_requests_ids is larger than the maximum mamba block capacity.
+
+    This could generally happen due to the fact that hybrid does support
+    statelessness mechanism where it can cleanup new incoming requests in
+    a single step.
+    """
     try:
-        with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model:
+        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
             vllm_model.generate_greedy([example_prompts[0]] * 100, 10)
     except ValueError:
         pytest.fail("Hybrid inner state wasn't cleaned up properly between"
                     "steps finished requests registered unnecessarily ")
 
 
-@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
 def test_state_cleanup(
     vllm_runner,
-    model: str,
-    dtype: str,
     example_prompts,
+    model: str,
 ) -> None:
-    # This test is for verifying that the Hybrid state is cleaned up between
-    # steps, If its not cleaned, an error would be expected.
+    """
+    This test is for verifying that the Hybrid state is cleaned up between
+    steps.
+
+    If its not cleaned, an error would be expected.
+    """
     try:
-        with vllm_runner(model, dtype=dtype) as vllm_model:
+        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
             for _ in range(10):
                 vllm_model.generate_greedy([example_prompts[0]] * 100, 1)
     except ValueError:
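
Note on the two tests above: both probe the same invariant, namely that the fixed-capacity mamba/hybrid state cache frees the slots of finished requests between steps. A schematic of that invariant (names and structure are illustrative, not vLLM's internals):

class StateCache:
    """Toy fixed-capacity state cache keyed by request id."""

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self.states: dict[str, object] = {}

    def allocate(self, request_id: str) -> None:
        if len(self.states) >= self.capacity:
            # This is the ValueError path the tests guard against.
            raise ValueError("no free state slots")
        self.states[request_id] = object()

    def release(self, finished_requests_ids: list[str]) -> None:
        # Skipping this cleanup makes allocate() fail once more than
        # `capacity` requests have been submitted in total.
        for request_id in finished_requests_ids:
            self.states.pop(request_id, None)
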
@@ -291,28 +260,14 @@ def test_state_cleanup(
                     "could be related to finished_requests_ids")
 
 
-@pytest.mark.skip(reason="RE-ENABLE: test is currently failing on main.")
-@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["float"])
-def test_multistep(
-    vllm_runner,
-    model: str,
-    dtype: str,
-    example_prompts,
-) -> None:
-    # This test is verifying that multistep works correctly
-    #on mamba-like models
-    with vllm_runner(model, num_scheduler_steps=8,
-                     max_num_seqs=2) as vllm_model:
-        vllm_model.generate_greedy([example_prompts[0]] * 10, 1)
-
-
-@pytest.mark.skip(reason="RE-ENABLE: test is currently failing on main.")
-@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
 @pytest.mark.parametrize("max_tokens", [64])
-def test_multistep_correctness(vllm_runner, model: str, dtype: str,
-                               max_tokens: int, example_prompts) -> None:
+def test_multistep_correctness(
+    vllm_runner,
+    example_prompts,
+    model: str,
+    max_tokens: int,
+) -> None:
     with vllm_runner(model, num_scheduler_steps=8,
                      max_num_seqs=2) as vllm_model:
         vllm_outputs_multistep = vllm_model.generate_greedy(
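
Note on multistep: with num_scheduler_steps=8 the engine runs up to eight decode iterations per scheduling pass, so the state caches must stay consistent across steps that were planned together. Back-of-envelope for the parameters used here (illustrative):

import math

max_tokens = 64
num_scheduler_steps = 8
# Roughly 64 / 8 = 8 scheduling rounds instead of 64 single-step rounds.
assert math.ceil(max_tokens / num_scheduler_steps) == 8
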
@@ -332,18 +287,21 @@ def test_multistep_correctness(vllm_runner, model: str, dtype: str,
 
 @multi_gpu_test(num_gpus=2)
-@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
 @pytest.mark.parametrize("max_tokens", [64])
 def test_hybrid_distributed_produces_identical_generation(
-        vllm_runner, model: str, dtype: str, max_tokens: int,
-        example_prompts) -> None:
-
-    with vllm_runner(model, dtype=dtype, tensor_parallel_size=2) as vllm_model:
+    vllm_runner,
+    example_prompts,
+    model: str,
+    max_tokens: int,
+) -> None:
+    with vllm_runner(model, tensor_parallel_size=2,
+                     max_num_seqs=2) as vllm_model:
         vllm_outputs_tp_2 = vllm_model.generate_greedy(example_prompts,
                                                        max_tokens)
 
-    with vllm_runner(model, dtype=dtype, tensor_parallel_size=1) as vllm_model:
+    with vllm_runner(model, tensor_parallel_size=1,
+                     max_num_seqs=2) as vllm_model:
         vllm_outputs_tp_1 = vllm_model.generate_greedy(example_prompts,
                                                        max_tokens)

File 3 of 3 (deleted)

@@ -1,337 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
"""Compare the outputs of HF and vLLM when using greedy sampling for Mamba.

Run `pytest tests/models/test_mamba.py`.
"""
import pytest
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from vllm.engine.arg_utils import EngineArgs
from vllm.sampling_params import SamplingParams

from ...utils import check_outputs_equal

MODELS = [
"state-spaces/mamba-130m-hf",
"tiiuae/falcon-mamba-tiny-dev",
# TODO: Compare to a Mamba2 model. The HF transformers implementation of
# Mamba2 is buggy for Codestral as it doesn't handle n_groups.
# See https://github.com/huggingface/transformers/pull/35943
# "mistralai/Mamba-Codestral-7B-v0.1",
]


# Use lower-level interfaces to create this greedy generator, as mamba will
# choke on the model_kwarg 'attention_mask' if hf_model.generate_greedy is used.
def generate_greedy(model_name, example_prompts, max_tokens):
# Create a text generation pipeline
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# Set the device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Generate texts from the prompts
outputs = []
for prompt in example_prompts:
# Tokenize the input prompt with truncation
inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
input_ids = inputs["input_ids"].to(model.device)
# Generate text using the model's generate method directly
generated_ids = model.generate(input_ids,
max_new_tokens=max_tokens,
do_sample=False)
generated_text = tokenizer.decode(generated_ids[0],
skip_special_tokens=True)
outputs.append((generated_ids[0].tolist(), generated_text))
return outputs
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
def test_models(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
hf_outputs = generate_greedy(model, example_prompts, max_tokens)
# Set max_num_seqs to keep Codestral from going OOM at fp32
with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
for i in range(len(example_prompts)):
hf_output_ids, hf_output_str = hf_outputs[i]
vllm_output_ids, vllm_output_str = vllm_outputs[i]
assert hf_output_str == vllm_output_str, (
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
assert hf_output_ids == vllm_output_ids, (
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
def test_batching(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
# To pass the small model tests, we need full precision.
for_loop_outputs = []
with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
for prompt in example_prompts:
for_loop_outputs.append(
vllm_model.generate_greedy([prompt], max_tokens)[0])
batched_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens)
check_outputs_equal(
outputs_0_lst=for_loop_outputs,
outputs_1_lst=batched_outputs,
name_0="for_loop_vllm",
name_1="batched_vllm",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [10])
def test_chunked_prefill_with_parallel_sampling(vllm_runner, example_prompts,
model: str, dtype: str,
max_tokens: int) -> None:
# Tests chunked prefill in conjunction with n>1. In this case, prefill is
# populated with decoding tokens and we test that it doesn't fail.
# This test might fail if cache is not allocated correctly for n > 1
# decoding steps inside a chunked prefill forward pass (where we have both
# prefill and decode together )
sampling_params = SamplingParams(n=3,
temperature=1,
seed=0,
max_tokens=max_tokens)
with vllm_runner(
model,
dtype=dtype,
enable_chunked_prefill=True,
max_num_batched_tokens=30,
max_num_seqs=10 # forces prefill chunks with decoding
) as vllm_model:
vllm_model.generate(example_prompts, sampling_params)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
def test_chunked_prefill(vllm_runner, example_prompts, model: str, dtype: str,
max_tokens: int,
chunked_prefill_token_size: int) -> None:
"""
Checks exact match decode between huggingface model and vllm runner with
chunked prefill.
"""
max_num_seqs = chunked_prefill_token_size
max_num_batched_tokens = chunked_prefill_token_size
non_chunked = generate_greedy(model, example_prompts, max_tokens)
with vllm_runner(model,
dtype=dtype,
enable_chunked_prefill=True,
max_num_batched_tokens=max_num_batched_tokens,
max_num_seqs=max_num_seqs) as vllm_model:
chunked = vllm_model.generate_greedy(example_prompts,
max_tokens=max_tokens)
check_outputs_equal(
outputs_0_lst=chunked,
outputs_1_lst=non_chunked,
name_0="chunked",
name_1="non_chunked",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [15])
def test_parallel_sampling(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
# Numerical differences produce slightly different output for these
if 'state-spaces' in model:
example_prompts.pop(0)
example_prompts.pop(0)
example_prompts.pop(0)
with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
for_loop_outputs = []
for _ in range(10):
for_loop_outputs.append(
vllm_model.generate_greedy(example_prompts, max_tokens)[0])
sampling_params = SamplingParams(n=10,
temperature=0.001,
seed=0,
max_tokens=max_tokens)
n_lt_1_outputs = vllm_model.generate(example_prompts, sampling_params)
token_ids, texts = n_lt_1_outputs[0]
n_lt_1_outputs = [(token_id, text)
for token_id, text in zip(token_ids, texts)]
check_outputs_equal(
outputs_0_lst=n_lt_1_outputs,
outputs_1_lst=for_loop_outputs,
name_0="vllm_n_lt_1_outputs",
name_1="vllm",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [20])
def test_mamba_cache_cg_padding(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
# This test is for verifying that mamba cache is padded to CG captured
# batch size. If it's not, a torch RuntimeError will be raised because
# tensor dimensions aren't compatible
vllm_config = EngineArgs(model=model).create_engine_config()
while len(example_prompts) == vllm_config.pad_for_cudagraph(
len(example_prompts)):
example_prompts.append(example_prompts[0])
try:
with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
except RuntimeError:
pytest.fail(
"Couldn't run batch size which is not equal to a Cuda Graph "
"captured batch size. "
"Could be related to mamba cache not padded correctly")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [20])
def test_models_preemption_recompute(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
# Tests that outputs are identical with and w/o preemtions (recompute)
assert dtype == "float"
with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
vllm_model.model.llm_engine.scheduler[
0].ENABLE_ARTIFICIAL_PREEMPT = True
preempt_vllm_outputs = vllm_model.generate_greedy(
example_prompts, max_tokens)
vllm_model.model.llm_engine.scheduler[
0].ENABLE_ARTIFICIAL_PREEMPT = False
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
check_outputs_equal(
outputs_0_lst=preempt_vllm_outputs,
outputs_1_lst=vllm_outputs,
name_0="vllm_preepmtions",
name_1="vllm",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks(
vllm_runner,
model: str,
dtype: str,
example_prompts,
) -> None:
# This test is for verifying that the Mamba inner state management doesn't
# collapse in case where the number of incoming requests and
# finished_requests_ids is larger than the maximum Mamba block capacity.
# This could generally happen due to the fact that Mamba does support
# statelessness mechanism where it can cleanup new incoming requests in
# a single step.
try:
with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model:
vllm_model.generate_greedy([example_prompts[0]] * 100, 10)
except ValueError:
pytest.fail("Mamba inner state wasn't cleaned up properly between"
"steps finished requests registered unnecessarily ")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_state_cleanup(
vllm_runner,
model: str,
dtype: str,
example_prompts,
) -> None:
# This test is for verifying that the Mamba state is cleaned up between
# steps, If its not cleaned, an error would be expected.
try:
with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
for _ in range(10):
vllm_model.generate_greedy([example_prompts[0]] * 100, 1)
except ValueError:
pytest.fail("Mamba inner state wasn't cleaned up between states, "
"could be related to finished_requests_ids")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_multistep(
vllm_runner,
model: str,
dtype: str,
example_prompts,
) -> None:
with vllm_runner(model, num_scheduler_steps=8,
max_num_seqs=2) as vllm_model:
vllm_model.generate_greedy([example_prompts[0]] * 10, 1)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [64])
def test_multistep_correctness(vllm_runner, model: str, dtype: str,
max_tokens: int, example_prompts) -> None:
with vllm_runner(model, num_scheduler_steps=8,
max_num_seqs=2) as vllm_model:
vllm_outputs_multistep = vllm_model.generate_greedy(
example_prompts, max_tokens)
with vllm_runner(model, num_scheduler_steps=1,
max_num_seqs=2) as vllm_model:
vllm_outputs_single_step = vllm_model.generate_greedy(
example_prompts, max_tokens)
check_outputs_equal(
outputs_0_lst=vllm_outputs_multistep,
outputs_1_lst=vllm_outputs_single_step,
name_0="vllm_outputs_multistep",
name_1="vllm_outputs_single_step",
)