mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 22:55:51 +08:00
Fix KV sharing fast prefill with cudagraph enabled (#28537)
Signed-off-by: Yong Hoon Shin <yhshin@meta.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
parent
4516d44b7f
commit
9324e10275
@ -4,13 +4,11 @@
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import CompilationConfig, CompilationMode
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
|
||||
from ...utils import fork_new_process_for_each_test
|
||||
from ...utils import check_answers, fork_new_process_for_each_test, prep_prompts
|
||||
|
||||
# global seed
|
||||
SEED = 42
|
||||
@ -45,28 +43,12 @@ def test_prompts():
|
||||
return prompts
|
||||
|
||||
|
||||
def cleanup(llm: LLM, compilation_config: CompilationConfig):
|
||||
# hacky: below lines are required to free up memory for the next test
|
||||
# when setting VLLM_ENABLE_V1_MULTIPROCESSING=0, del llm is not sufficient
|
||||
# TODO(sarckk): when enforce_eager=False, memory is not freed:
|
||||
# find out why and re-enable test for enforce_eager=False case
|
||||
llm_engine = llm.llm_engine.engine_core.engine_core
|
||||
model_runner = llm_engine.model_executor.driver_worker.worker.model_runner
|
||||
del model_runner.model
|
||||
del model_runner.kv_caches
|
||||
del compilation_config.static_forward_context
|
||||
compilation_config.static_forward_context = {}
|
||||
|
||||
del llm
|
||||
torch.cuda.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
@fork_new_process_for_each_test
|
||||
@pytest.mark.parametrize("enforce_eager", [True])
|
||||
@pytest.mark.skip(reason="Disable until Gemma3n supports fast prefill")
|
||||
@pytest.mark.parametrize("kv_sharing_fast_prefill", [False, True])
|
||||
@pytest.mark.parametrize("enforce_eager", [True, False])
|
||||
def test_kv_sharing_fast_prefill(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
kv_sharing_fast_prefill: bool,
|
||||
enforce_eager: bool,
|
||||
test_prompts: list[str],
|
||||
):
|
||||
@ -79,36 +61,25 @@ def test_kv_sharing_fast_prefill(
|
||||
if not enforce_eager
|
||||
else CompilationMode.NONE,
|
||||
)
|
||||
batch_size = 10
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
# Make scheduling deterministic for reproducibility
|
||||
m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||
|
||||
llm = LLM(
|
||||
model="google/gemma-3n-E2B-it",
|
||||
enforce_eager=enforce_eager,
|
||||
compilation_config=compilation_config,
|
||||
seed=SEED,
|
||||
)
|
||||
ref_responses = llm.generate(test_prompts, sampling_params)
|
||||
|
||||
cleanup(llm, compilation_config)
|
||||
prompts, answer, indices = prep_prompts(batch_size)
|
||||
|
||||
llm = LLM(
|
||||
model="google/gemma-3n-E2B-it",
|
||||
enforce_eager=enforce_eager,
|
||||
compilation_config=compilation_config,
|
||||
seed=SEED,
|
||||
kv_sharing_fast_prefill=True,
|
||||
kv_sharing_fast_prefill=kv_sharing_fast_prefill,
|
||||
)
|
||||
responses = llm.generate(prompts, sampling_params)
|
||||
check_answers(
|
||||
indices,
|
||||
answer,
|
||||
[response.outputs[0].text for response in responses],
|
||||
accept_rate=1.0,
|
||||
)
|
||||
optimized_responses = llm.generate(test_prompts, sampling_params)
|
||||
|
||||
cleanup(llm, compilation_config)
|
||||
|
||||
misses = 0
|
||||
|
||||
for ref_response, optimized_response in zip(ref_responses, optimized_responses):
|
||||
if ref_response.outputs[0].text != optimized_response.outputs[0].text:
|
||||
misses += 1
|
||||
|
||||
assert misses == 0
|
||||
|
||||
@ -965,12 +965,6 @@ def reshape_attn_output_for_spec_decode(attn_output: torch.Tensor) -> torch.Tens
|
||||
return attn_output.view(total_tokens, attn_output.shape[2], attn_output.shape[3])
|
||||
|
||||
|
||||
KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [
|
||||
("logits_indices_padded", torch.Tensor | None, None),
|
||||
("num_logits_indices", int, 0),
|
||||
]
|
||||
|
||||
|
||||
def subclass_attention_metadata(
|
||||
name_prefix: str,
|
||||
metadata_cls: Any,
|
||||
@ -986,8 +980,8 @@ def subclass_attention_metadata(
|
||||
|
||||
@runtime_checkable
|
||||
class KVSharingFastPrefillMetadata(Protocol):
|
||||
logits_indices_padded: torch.Tensor
|
||||
num_logits_indices: int
|
||||
logits_indices_padded: torch.Tensor | None = None
|
||||
num_logits_indices: int | None = None
|
||||
|
||||
|
||||
def create_fast_prefill_custom_backend(
|
||||
@ -1019,11 +1013,6 @@ def create_fast_prefill_custom_backend(
|
||||
for _field in fields(metadata.__class__):
|
||||
setattr(self, _field.name, getattr(metadata, _field.name))
|
||||
|
||||
# Set additional fields that will be used in model code
|
||||
assert (
|
||||
common_attn_metadata.logits_indices_padded is not None
|
||||
and common_attn_metadata.num_logits_indices is not None
|
||||
)
|
||||
self.logits_indices_padded = (
|
||||
common_attn_metadata.logits_indices_padded
|
||||
)
|
||||
|
||||
@ -1314,7 +1314,7 @@ class GPUModelRunner(
|
||||
:return: tuple[attn_metadata, spec_decode_common_attn_metadata]
|
||||
"""
|
||||
logits_indices_padded = None
|
||||
num_logits_indices = 0
|
||||
num_logits_indices = None
|
||||
if logits_indices is not None:
|
||||
num_logits_indices = logits_indices.size(0)
|
||||
if self.cache_config.kv_sharing_fast_prefill:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user