mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-20 00:07:00 +08:00
[V1][Mamba] - Enable V1 by default for Mamba Models (#23650)
Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
This commit is contained in:
parent
8bf6266a17
commit
853c371fc3
@ -100,21 +100,19 @@ def test_models(
|
|||||||
else:
|
else:
|
||||||
hf_outputs = None
|
hf_outputs = None
|
||||||
|
|
||||||
if model not in V0_UNSUPPORTED_MODELS:
|
with monkeypatch.context() as m:
|
||||||
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
|
m.setenv("VLLM_USE_V1", "0")
|
||||||
vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
|
if model not in V0_UNSUPPORTED_MODELS:
|
||||||
example_prompts, max_tokens, num_logprobs)
|
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
|
||||||
else:
|
vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
|
||||||
vllm_v0_outputs = None
|
example_prompts, max_tokens, num_logprobs)
|
||||||
|
else:
|
||||||
|
vllm_v0_outputs = None
|
||||||
|
|
||||||
if model in V1_SUPPORTED_MODELS:
|
if model in V1_SUPPORTED_MODELS:
|
||||||
with monkeypatch.context() as m:
|
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
|
||||||
m.setenv("VLLM_USE_V1", "1")
|
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
|
||||||
with vllm_runner(model,
|
example_prompts, max_tokens, num_logprobs)
|
||||||
max_num_seqs=MAX_NUM_SEQS,
|
|
||||||
enable_prefix_caching=False) as vllm_model:
|
|
||||||
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
|
|
||||||
example_prompts, max_tokens, num_logprobs)
|
|
||||||
else:
|
else:
|
||||||
vllm_v1_outputs = None
|
vllm_v1_outputs = None
|
||||||
|
|
||||||
@ -137,7 +135,7 @@ def test_models(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
|
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
|
||||||
@pytest.mark.parametrize("max_tokens", [64])
|
@pytest.mark.parametrize("max_tokens", [64])
|
||||||
@pytest.mark.parametrize("num_logprobs", [5])
|
@pytest.mark.parametrize("num_logprobs", [5])
|
||||||
def test_batching(
|
def test_batching(
|
||||||
@ -147,10 +145,6 @@ def test_batching(
|
|||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
num_logprobs: int,
|
num_logprobs: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
if model in V0_UNSUPPORTED_MODELS:
|
|
||||||
pytest.skip(
|
|
||||||
f"Unsupported V0 Engine. Skipping `test_batching` on {model}.")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
|
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
|
||||||
model_info.check_available_online(on_fail="skip")
|
model_info.check_available_online(on_fail="skip")
|
||||||
@ -188,29 +182,32 @@ def test_chunked_prefill(
|
|||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
num_logprobs: int,
|
num_logprobs: int,
|
||||||
chunked_prefill_token_size: int,
|
chunked_prefill_token_size: int,
|
||||||
|
monkeypatch,
|
||||||
) -> None:
|
) -> None:
|
||||||
max_num_seqs = chunked_prefill_token_size
|
max_num_seqs = chunked_prefill_token_size
|
||||||
max_num_batched_tokens = chunked_prefill_token_size
|
max_num_batched_tokens = chunked_prefill_token_size
|
||||||
|
|
||||||
with vllm_runner(model,
|
with monkeypatch.context() as m:
|
||||||
enable_chunked_prefill=True,
|
m.setenv("VLLM_USE_V1", "0")
|
||||||
max_num_batched_tokens=max_num_batched_tokens,
|
with vllm_runner(model,
|
||||||
max_num_seqs=max_num_seqs) as vllm_model:
|
enable_chunked_prefill=True,
|
||||||
chunked = vllm_model.generate_greedy_logprobs(example_prompts,
|
max_num_batched_tokens=max_num_batched_tokens,
|
||||||
max_tokens, num_logprobs)
|
max_num_seqs=max_num_seqs) as vllm_model:
|
||||||
|
chunked = vllm_model.generate_greedy_logprobs(
|
||||||
|
example_prompts, max_tokens, num_logprobs)
|
||||||
|
|
||||||
with vllm_runner(model,
|
with vllm_runner(model,
|
||||||
enable_chunked_prefill=False,
|
enable_chunked_prefill=False,
|
||||||
max_num_seqs=max_num_seqs) as vllm_model:
|
max_num_seqs=max_num_seqs) as vllm_model:
|
||||||
non_chunked = vllm_model.generate_greedy_logprobs(
|
non_chunked = vllm_model.generate_greedy_logprobs(
|
||||||
example_prompts, max_tokens, num_logprobs)
|
example_prompts, max_tokens, num_logprobs)
|
||||||
|
|
||||||
check_logprobs_close(
|
check_logprobs_close(
|
||||||
outputs_0_lst=chunked,
|
outputs_0_lst=chunked,
|
||||||
outputs_1_lst=non_chunked,
|
outputs_1_lst=non_chunked,
|
||||||
name_0="chunked",
|
name_0="chunked",
|
||||||
name_1="non_chunked",
|
name_1="non_chunked",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
|
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
|
||||||
@ -281,25 +278,29 @@ def test_models_preemption_recompute(
|
|||||||
example_prompts,
|
example_prompts,
|
||||||
model: str,
|
model: str,
|
||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
|
monkeypatch,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Tests that outputs are identical with and w/o preemptions (recompute).
|
Tests that outputs are identical with and w/o preemptions (recompute).
|
||||||
"""
|
"""
|
||||||
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
|
with monkeypatch.context() as m:
|
||||||
scheduler = vllm_model.llm.llm_engine.scheduler[0]
|
m.setenv("VLLM_USE_V1", "0")
|
||||||
scheduler.ENABLE_ARTIFICIAL_PREEMPT = True
|
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
|
||||||
preempt_vllm_outputs = vllm_model.generate_greedy(
|
scheduler = vllm_model.llm.llm_engine.scheduler[0]
|
||||||
example_prompts, max_tokens)
|
scheduler.ENABLE_ARTIFICIAL_PREEMPT = True
|
||||||
|
preempt_vllm_outputs = vllm_model.generate_greedy(
|
||||||
|
example_prompts, max_tokens)
|
||||||
|
|
||||||
scheduler.ENABLE_ARTIFICIAL_PREEMPT = False
|
scheduler.ENABLE_ARTIFICIAL_PREEMPT = False
|
||||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
vllm_outputs = vllm_model.generate_greedy(example_prompts,
|
||||||
|
max_tokens)
|
||||||
|
|
||||||
check_outputs_equal(
|
check_outputs_equal(
|
||||||
outputs_0_lst=preempt_vllm_outputs,
|
outputs_0_lst=preempt_vllm_outputs,
|
||||||
outputs_1_lst=vllm_outputs,
|
outputs_1_lst=vllm_outputs,
|
||||||
name_0="vllm_preepmtions",
|
name_0="vllm_preepmtions",
|
||||||
name_1="vllm",
|
name_1="vllm",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
|
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
|
||||||
@ -402,24 +403,18 @@ def test_full_cuda_graph(
|
|||||||
else:
|
else:
|
||||||
hf_outputs = None
|
hf_outputs = None
|
||||||
|
|
||||||
if model not in V0_UNSUPPORTED_MODELS:
|
|
||||||
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
|
|
||||||
vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
|
|
||||||
example_prompts, max_tokens, num_logprobs)
|
|
||||||
else:
|
|
||||||
vllm_v0_outputs = None
|
|
||||||
|
|
||||||
with monkeypatch.context() as m:
|
with monkeypatch.context() as m:
|
||||||
m.setenv("VLLM_USE_V1", "1")
|
m.setenv("VLLM_USE_V1", "0")
|
||||||
if model in HYBRID_MODELS:
|
if model not in V0_UNSUPPORTED_MODELS:
|
||||||
# required due to reorder_batch behaviour
|
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
|
||||||
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
|
vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
|
||||||
with vllm_runner(model,
|
example_prompts, max_tokens, num_logprobs)
|
||||||
max_num_seqs=MAX_NUM_SEQS,
|
else:
|
||||||
compilation_config={'full_cuda_graph': True},
|
vllm_v0_outputs = None
|
||||||
enable_prefix_caching=False) as vllm_model:
|
|
||||||
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
|
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
|
||||||
example_prompts, max_tokens, num_logprobs)
|
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
|
||||||
|
example_prompts, max_tokens, num_logprobs)
|
||||||
|
|
||||||
if hf_outputs is not None and vllm_v0_outputs is not None:
|
if hf_outputs is not None and vllm_v0_outputs is not None:
|
||||||
check_logprobs_close(
|
check_logprobs_close(
|
||||||
@ -466,24 +461,20 @@ def test_fp32_state(
|
|||||||
else:
|
else:
|
||||||
hf_outputs = None
|
hf_outputs = None
|
||||||
|
|
||||||
|
with monkeypatch.context() as m:
|
||||||
|
m.setenv("VLLM_USE_V1", "0")
|
||||||
|
with vllm_runner(model,
|
||||||
|
max_num_seqs=MAX_NUM_SEQS,
|
||||||
|
mamba_ssm_cache_dtype="float32") as vllm_model:
|
||||||
|
vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
|
||||||
|
example_prompts, max_tokens, num_logprobs)
|
||||||
|
|
||||||
with vllm_runner(model,
|
with vllm_runner(model,
|
||||||
max_num_seqs=MAX_NUM_SEQS,
|
max_num_seqs=MAX_NUM_SEQS,
|
||||||
mamba_ssm_cache_dtype="float32") as vllm_model:
|
mamba_ssm_cache_dtype="float32") as vllm_model:
|
||||||
vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
|
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
|
||||||
example_prompts, max_tokens, num_logprobs)
|
example_prompts, max_tokens, num_logprobs)
|
||||||
|
|
||||||
with monkeypatch.context() as m:
|
|
||||||
m.setenv("VLLM_USE_V1", "1")
|
|
||||||
if model in HYBRID_MODELS:
|
|
||||||
# required due to reorder_batch behaviour
|
|
||||||
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
|
|
||||||
with vllm_runner(model,
|
|
||||||
max_num_seqs=MAX_NUM_SEQS,
|
|
||||||
mamba_ssm_cache_dtype="float32",
|
|
||||||
enable_prefix_caching=False) as vllm_model:
|
|
||||||
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
|
|
||||||
example_prompts, max_tokens, num_logprobs)
|
|
||||||
|
|
||||||
if hf_outputs is not None:
|
if hf_outputs is not None:
|
||||||
check_logprobs_close(
|
check_logprobs_close(
|
||||||
outputs_0_lst=hf_outputs,
|
outputs_0_lst=hf_outputs,
|
||||||
|
|||||||
@ -1463,11 +1463,6 @@ class EngineArgs:
|
|||||||
recommend_to_remove=False)
|
recommend_to_remove=False)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# V1 mamba models are unoptimized.
|
|
||||||
if model_config.has_inner_state and _warn_or_fallback(
|
|
||||||
feature_name="Mamba"):
|
|
||||||
return False
|
|
||||||
|
|
||||||
# No Concurrent Partial Prefills so far.
|
# No Concurrent Partial Prefills so far.
|
||||||
if (self.max_num_partial_prefills
|
if (self.max_num_partial_prefills
|
||||||
!= SchedulerConfig.max_num_partial_prefills
|
!= SchedulerConfig.max_num_partial_prefills
|
||||||
|
|||||||
@ -417,4 +417,5 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
|
|||||||
"GptOssForCausalLM": GptOssForCausalLMConfig,
|
"GptOssForCausalLM": GptOssForCausalLMConfig,
|
||||||
"MambaForCausalLM": MambaModelConfig,
|
"MambaForCausalLM": MambaModelConfig,
|
||||||
"Mamba2ForCausalLM": MambaModelConfig,
|
"Mamba2ForCausalLM": MambaModelConfig,
|
||||||
|
"FalconMambaForCausalLM": MambaModelConfig,
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user