Fix KV sharing fast prefill with cudagraph enabled (#28537)

Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Yong Hoon Shin authored 2025-11-14 01:53:42 -10:00; committed by GitHub
parent 4516d44b7f
commit 9324e10275
3 changed files with 17 additions and 57 deletions


@@ -4,13 +4,11 @@
 import random
 
 import pytest
-import torch
 
 from vllm import LLM, SamplingParams
 from vllm.config import CompilationConfig, CompilationMode
-from vllm.distributed import cleanup_dist_env_and_memory
 
-from ...utils import fork_new_process_for_each_test
+from ...utils import check_answers, fork_new_process_for_each_test, prep_prompts
 
 # global seed
 SEED = 42
@@ -45,28 +43,12 @@ def test_prompts():
     return prompts
 
 
-def cleanup(llm: LLM, compilation_config: CompilationConfig):
-    # hacky: below lines are required to free up memory for the next test
-    # when setting VLLM_ENABLE_V1_MULTIPROCESSING=0, del llm is not sufficient
-    # TODO(sarckk): when enforce_eager=False, memory is not freed:
-    # find out why and re-enable test for enforce_eager=False case
-    llm_engine = llm.llm_engine.engine_core.engine_core
-    model_runner = llm_engine.model_executor.driver_worker.worker.model_runner
-    del model_runner.model
-    del model_runner.kv_caches
-    del compilation_config.static_forward_context
-    compilation_config.static_forward_context = {}
-    del llm
-    torch.cuda.empty_cache()
-    cleanup_dist_env_and_memory()
-
-
 @fork_new_process_for_each_test
-@pytest.mark.parametrize("enforce_eager", [True])
-@pytest.mark.skip(reason="Disable until Gemma3n supports fast prefill")
+@pytest.mark.parametrize("kv_sharing_fast_prefill", [False, True])
+@pytest.mark.parametrize("enforce_eager", [True, False])
 def test_kv_sharing_fast_prefill(
     monkeypatch: pytest.MonkeyPatch,
+    kv_sharing_fast_prefill: bool,
     enforce_eager: bool,
     test_prompts: list[str],
 ):
@@ -79,36 +61,25 @@ def test_kv_sharing_fast_prefill(
         if not enforce_eager
         else CompilationMode.NONE,
     )
+    batch_size = 10
 
     with monkeypatch.context() as m:
         # Make scheduling deterministic for reproducibility
        m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
-        llm = LLM(
-            model="google/gemma-3n-E2B-it",
-            enforce_eager=enforce_eager,
-            compilation_config=compilation_config,
-            seed=SEED,
-        )
-        ref_responses = llm.generate(test_prompts, sampling_params)
-
-        cleanup(llm, compilation_config)
-
+        prompts, answer, indices = prep_prompts(batch_size)
         llm = LLM(
             model="google/gemma-3n-E2B-it",
             enforce_eager=enforce_eager,
             compilation_config=compilation_config,
             seed=SEED,
-            kv_sharing_fast_prefill=True,
+            kv_sharing_fast_prefill=kv_sharing_fast_prefill,
         )
-        optimized_responses = llm.generate(test_prompts, sampling_params)
-
-        cleanup(llm, compilation_config)
-
-        misses = 0
-        for ref_response, optimized_response in zip(ref_responses, optimized_responses):
-            if ref_response.outputs[0].text != optimized_response.outputs[0].text:
-                misses += 1
-        assert misses == 0
+        responses = llm.generate(prompts, sampling_params)
+        check_answers(
+            indices,
+            answer,
+            [response.outputs[0].text for response in responses],
+            accept_rate=1.0,
+        )
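
Note on the test change above: stacked pytest.mark.parametrize decorators run the test over the cross product of their values, so with the skip marker removed the suite now also exercises the cudagraph path (enforce_eager=False) together with kv_sharing_fast_prefill=True, the combination named in the commit title. A minimal, self-contained sketch of that parametrization mechanism (an illustrative test, not part of this diff):

import pytest

@pytest.mark.parametrize("kv_sharing_fast_prefill", [False, True])
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_param_cross_product(enforce_eager: bool, kv_sharing_fast_prefill: bool) -> None:
    # Stacked parametrize decorators expand to 2 x 2 = 4 test cases;
    # (enforce_eager=False, kv_sharing_fast_prefill=True) is the
    # cudagraph + fast-prefill combination this commit fixes.
    assert isinstance(enforce_eager, bool)
    assert isinstance(kv_sharing_fast_prefill, bool)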


@@ -965,12 +965,6 @@ def reshape_attn_output_for_spec_decode(attn_output: torch.Tensor) -> torch.Tensor:
     return attn_output.view(total_tokens, attn_output.shape[2], attn_output.shape[3])
 
 
-KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [
-    ("logits_indices_padded", torch.Tensor | None, None),
-    ("num_logits_indices", int, 0),
-]
-
-
 def subclass_attention_metadata(
     name_prefix: str,
     metadata_cls: Any,
@@ -986,8 +980,8 @@ def subclass_attention_metadata(
 @runtime_checkable
 class KVSharingFastPrefillMetadata(Protocol):
-    logits_indices_padded: torch.Tensor
-    num_logits_indices: int
+    logits_indices_padded: torch.Tensor | None = None
+    num_logits_indices: int | None = None
 
 
 def create_fast_prefill_custom_backend(
@@ -1019,11 +1013,6 @@ def create_fast_prefill_custom_backend(
                 for _field in fields(metadata.__class__):
                     setattr(self, _field.name, getattr(metadata, _field.name))
 
-                # Set additional fields that will be used in model code
-                assert (
-                    common_attn_metadata.logits_indices_padded is not None
-                    and common_attn_metadata.num_logits_indices is not None
-                )
                 self.logits_indices_padded = (
                     common_attn_metadata.logits_indices_padded
                 )
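
The two hunks above make the KVSharingFastPrefillMetadata protocol members optional with None defaults and drop the assertion that common_attn_metadata's logits_indices_padded and num_logits_indices are not None. As a generic illustration of the runtime_checkable Protocol pattern involved (illustrative class names; a stand-in, not vLLM's actual metadata type):

from typing import Protocol, runtime_checkable

@runtime_checkable
class HasFastPrefillFields(Protocol):
    # Members default to None, mirroring the change above.
    logits_indices_padded: object | None = None
    num_logits_indices: int | None = None

class FakeAttnMetadata:
    def __init__(self, num_logits_indices: int | None = None) -> None:
        self.logits_indices_padded = None
        self.num_logits_indices = num_logits_indices

# isinstance() against a runtime_checkable protocol only checks that the named
# attributes exist on the instance; it does not validate their values, so
# None-valued fields still satisfy the protocol.
assert isinstance(FakeAttnMetadata(), HasFastPrefillFields)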


@@ -1314,7 +1314,7 @@ class GPUModelRunner(
         :return: tuple[attn_metadata, spec_decode_common_attn_metadata]
         """
         logits_indices_padded = None
-        num_logits_indices = 0
+        num_logits_indices = None
         if logits_indices is not None:
             num_logits_indices = logits_indices.size(0)
             if self.cache_config.kv_sharing_fast_prefill:
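
One plausible reading of the one-line GPUModelRunner change above (a hedged sketch, not something stated in the diff): defaulting num_logits_indices to None rather than 0 lets downstream consumers distinguish "logits indices were never computed" from "zero logits indices". A hypothetical helper illustrating that distinction:

def describe_num_logits_indices(num_logits_indices: int | None) -> str:
    # Hypothetical, illustrative only: None as the sentinel separates
    # "not provided" from "provided but empty", which a 0 default cannot.
    if num_logits_indices is None:
        return "logits indices were not computed for this batch"
    return f"{num_logits_indices} logits indices"

print(describe_num_logits_indices(None))  # not computed
print(describe_num_logits_indices(0))     # computed, but empty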