Fix KV sharing fast prefill with cudagraph enabled (#28537)

Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
Yong Hoon Shin 2025-11-14 01:53:42 -10:00 committed by GitHub
parent 4516d44b7f
commit 9324e10275
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 17 additions and 57 deletions

View File

@ -4,13 +4,11 @@
import random
import pytest
import torch
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationMode
from vllm.distributed import cleanup_dist_env_and_memory
from ...utils import fork_new_process_for_each_test
from ...utils import check_answers, fork_new_process_for_each_test, prep_prompts
# global seed
SEED = 42
@ -45,28 +43,12 @@ def test_prompts():
return prompts
def cleanup(llm: LLM, compilation_config: CompilationConfig):
# hacky: below lines are required to free up memory for the next test
# when setting VLLM_ENABLE_V1_MULTIPROCESSING=0, del llm is not sufficient
# TODO(sarckk): when enforce_eager=False, memory is not freed:
# find out why and re-enable test for enforce_eager=False case
llm_engine = llm.llm_engine.engine_core.engine_core
model_runner = llm_engine.model_executor.driver_worker.worker.model_runner
del model_runner.model
del model_runner.kv_caches
del compilation_config.static_forward_context
compilation_config.static_forward_context = {}
del llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
@fork_new_process_for_each_test
@pytest.mark.parametrize("enforce_eager", [True])
@pytest.mark.skip(reason="Disable until Gemma3n supports fast prefill")
@pytest.mark.parametrize("kv_sharing_fast_prefill", [False, True])
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_kv_sharing_fast_prefill(
monkeypatch: pytest.MonkeyPatch,
kv_sharing_fast_prefill: bool,
enforce_eager: bool,
test_prompts: list[str],
):
@ -79,36 +61,25 @@ def test_kv_sharing_fast_prefill(
if not enforce_eager
else CompilationMode.NONE,
)
batch_size = 10
with monkeypatch.context() as m:
# Make scheduling deterministic for reproducibility
m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
llm = LLM(
model="google/gemma-3n-E2B-it",
enforce_eager=enforce_eager,
compilation_config=compilation_config,
seed=SEED,
)
ref_responses = llm.generate(test_prompts, sampling_params)
cleanup(llm, compilation_config)
prompts, answer, indices = prep_prompts(batch_size)
llm = LLM(
model="google/gemma-3n-E2B-it",
enforce_eager=enforce_eager,
compilation_config=compilation_config,
seed=SEED,
kv_sharing_fast_prefill=True,
kv_sharing_fast_prefill=kv_sharing_fast_prefill,
)
responses = llm.generate(prompts, sampling_params)
check_answers(
indices,
answer,
[response.outputs[0].text for response in responses],
accept_rate=1.0,
)
optimized_responses = llm.generate(test_prompts, sampling_params)
cleanup(llm, compilation_config)
misses = 0
for ref_response, optimized_response in zip(ref_responses, optimized_responses):
if ref_response.outputs[0].text != optimized_response.outputs[0].text:
misses += 1
assert misses == 0

View File

@ -965,12 +965,6 @@ def reshape_attn_output_for_spec_decode(attn_output: torch.Tensor) -> torch.Tens
return attn_output.view(total_tokens, attn_output.shape[2], attn_output.shape[3])
KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [
("logits_indices_padded", torch.Tensor | None, None),
("num_logits_indices", int, 0),
]
def subclass_attention_metadata(
name_prefix: str,
metadata_cls: Any,
@ -986,8 +980,8 @@ def subclass_attention_metadata(
@runtime_checkable
class KVSharingFastPrefillMetadata(Protocol):
logits_indices_padded: torch.Tensor
num_logits_indices: int
logits_indices_padded: torch.Tensor | None = None
num_logits_indices: int | None = None
def create_fast_prefill_custom_backend(
@ -1019,11 +1013,6 @@ def create_fast_prefill_custom_backend(
for _field in fields(metadata.__class__):
setattr(self, _field.name, getattr(metadata, _field.name))
# Set additional fields that will be used in model code
assert (
common_attn_metadata.logits_indices_padded is not None
and common_attn_metadata.num_logits_indices is not None
)
self.logits_indices_padded = (
common_attn_metadata.logits_indices_padded
)

View File

@ -1314,7 +1314,7 @@ class GPUModelRunner(
:return: tuple[attn_metadata, spec_decode_common_attn_metadata]
"""
logits_indices_padded = None
num_logits_indices = 0
num_logits_indices = None
if logits_indices is not None:
num_logits_indices = logits_indices.size(0)
if self.cache_config.kv_sharing_fast_prefill: