Fix KV sharing fast prefill with cudagraph enabled (#28537)
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
parent 4516d44b7f
commit 9324e10275
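For context, the feature this commit fixes is driven by an engine argument that is passed straight through LLM(). A minimal usage sketch, not part of the diff: the model name and the kv_sharing_fast_prefill argument are taken from the test below, everything else (prompt, sampling settings) is illustrative.

from vllm import LLM, SamplingParams

# Sketch: enable KV-sharing fast prefill while leaving CUDA graphs on
# (enforce_eager=False), the combination this commit fixes.
llm = LLM(
    model="google/gemma-3n-E2B-it",
    enforce_eager=False,
    kv_sharing_fast_prefill=True,
    seed=42,
)
sampling_params = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["The capital of France is"], sampling_params)
print(outputs[0].outputs[0].text)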
@@ -4,13 +4,11 @@
 import random
 
 import pytest
-import torch
 
 from vllm import LLM, SamplingParams
 from vllm.config import CompilationConfig, CompilationMode
-from vllm.distributed import cleanup_dist_env_and_memory
 
-from ...utils import fork_new_process_for_each_test
+from ...utils import check_answers, fork_new_process_for_each_test, prep_prompts
 
 # global seed
 SEED = 42
@@ -45,28 +43,12 @@ def test_prompts():
     return prompts
 
 
-def cleanup(llm: LLM, compilation_config: CompilationConfig):
-    # hacky: below lines are required to free up memory for the next test
-    # when setting VLLM_ENABLE_V1_MULTIPROCESSING=0, del llm is not sufficient
-    # TODO(sarckk): when enforce_eager=False, memory is not freed:
-    # find out why and re-enable test for enforce_eager=False case
-    llm_engine = llm.llm_engine.engine_core.engine_core
-    model_runner = llm_engine.model_executor.driver_worker.worker.model_runner
-    del model_runner.model
-    del model_runner.kv_caches
-    del compilation_config.static_forward_context
-    compilation_config.static_forward_context = {}
-
-    del llm
-    torch.cuda.empty_cache()
-    cleanup_dist_env_and_memory()
-
-
 @fork_new_process_for_each_test
-@pytest.mark.parametrize("enforce_eager", [True])
-@pytest.mark.skip(reason="Disable until Gemma3n supports fast prefill")
+@pytest.mark.parametrize("kv_sharing_fast_prefill", [False, True])
+@pytest.mark.parametrize("enforce_eager", [True, False])
 def test_kv_sharing_fast_prefill(
     monkeypatch: pytest.MonkeyPatch,
+    kv_sharing_fast_prefill: bool,
     enforce_eager: bool,
     test_prompts: list[str],
 ):
@@ -79,36 +61,25 @@ def test_kv_sharing_fast_prefill(
         if not enforce_eager
         else CompilationMode.NONE,
     )
+    batch_size = 10
 
     with monkeypatch.context() as m:
         # Make scheduling deterministic for reproducibility
         m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
 
-        llm = LLM(
-            model="google/gemma-3n-E2B-it",
-            enforce_eager=enforce_eager,
-            compilation_config=compilation_config,
-            seed=SEED,
-        )
-        ref_responses = llm.generate(test_prompts, sampling_params)
-
-        cleanup(llm, compilation_config)
+        prompts, answer, indices = prep_prompts(batch_size)
 
         llm = LLM(
            model="google/gemma-3n-E2B-it",
            enforce_eager=enforce_eager,
            compilation_config=compilation_config,
            seed=SEED,
-           kv_sharing_fast_prefill=True,
+           kv_sharing_fast_prefill=kv_sharing_fast_prefill,
+        )
+        responses = llm.generate(prompts, sampling_params)
+        check_answers(
+            indices,
+            answer,
+            [response.outputs[0].text for response in responses],
+            accept_rate=1.0,
        )
-        optimized_responses = llm.generate(test_prompts, sampling_params)
-
-        cleanup(llm, compilation_config)
-
-        misses = 0
-
-        for ref_response, optimized_response in zip(ref_responses, optimized_responses):
-            if ref_response.outputs[0].text != optimized_response.outputs[0].text:
-                misses += 1
-
-        assert misses == 0
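The reference-vs-optimized comparison in the old test is replaced by the shared prep_prompts/check_answers helpers, which score each generation against a known answer and require a minimum accept rate (1.0 here) rather than exact agreement between two separate runs. A minimal sketch of that style of check, assuming answers and outputs are aligned lists; the real check_answers helper in the shared test utilities may differ in details.

# Sketch of an accept-rate check in the style of check_answers; not the actual helper.
def check_answers_sketch(
    indices: list[int],
    answers: list[str],
    outputs: list[str],
    accept_rate: float = 0.7,
) -> None:
    # Each prompt asks the model to recall the value recorded in answers[i];
    # indices record which variable was probed (useful when debugging failures).
    num_ok = sum(1 for expected, text in zip(answers, outputs) if str(expected) in text)
    frac_ok = num_ok / max(len(answers), 1)
    print(f"Num OK: {num_ok}/{len(answers)} ({frac_ok:.2f}), probed indices: {indices}")
    assert frac_ok >= accept_rate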
@@ -965,12 +965,6 @@ def reshape_attn_output_for_spec_decode(attn_output: torch.Tensor) -> torch.Tensor:
     return attn_output.view(total_tokens, attn_output.shape[2], attn_output.shape[3])
 
 
-KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [
-    ("logits_indices_padded", torch.Tensor | None, None),
-    ("num_logits_indices", int, 0),
-]
-
-
 def subclass_attention_metadata(
     name_prefix: str,
     metadata_cls: Any,
@@ -986,8 +980,8 @@ def subclass_attention_metadata(
 
 @runtime_checkable
 class KVSharingFastPrefillMetadata(Protocol):
-    logits_indices_padded: torch.Tensor
-    num_logits_indices: int
+    logits_indices_padded: torch.Tensor | None = None
+    num_logits_indices: int | None = None
 
 
 def create_fast_prefill_custom_backend(
@@ -1019,11 +1013,6 @@ def create_fast_prefill_custom_backend(
             for _field in fields(metadata.__class__):
                 setattr(self, _field.name, getattr(metadata, _field.name))
 
-            # Set additional fields that will be used in model code
-            assert (
-                common_attn_metadata.logits_indices_padded is not None
-                and common_attn_metadata.num_logits_indices is not None
-            )
             self.logits_indices_padded = (
                 common_attn_metadata.logits_indices_padded
             )
@@ -1314,7 +1314,7 @@ class GPUModelRunner(
         :return: tuple[attn_metadata, spec_decode_common_attn_metadata]
         """
         logits_indices_padded = None
-        num_logits_indices = 0
+        num_logits_indices = None
         if logits_indices is not None:
             num_logits_indices = logits_indices.size(0)
         if self.cache_config.kv_sharing_fast_prefill:
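The KVSharingFastPrefillMetadata Protocol above now declares its fields as optional with None defaults, so metadata objects that do not carry the fast-prefill fields no longer need placeholder values. A small self-contained sketch of how a runtime_checkable Protocol with data members behaves under isinstance; the classes here are illustrative stand-ins, not vLLM's real metadata types.

from dataclasses import dataclass
from typing import Protocol, runtime_checkable


@runtime_checkable
class FastPrefillFields(Protocol):
    logits_indices_padded: object | None
    num_logits_indices: int | None


@dataclass
class PlainMetadata:
    num_actual_tokens: int


@dataclass
class FastPrefillMetadata:
    num_actual_tokens: int
    logits_indices_padded: object | None = None
    num_logits_indices: int | None = None


def maybe_fast_prefill(meta: object) -> bool:
    # isinstance against a runtime_checkable Protocol checks attribute presence,
    # so only objects exposing both fast-prefill fields match.
    return isinstance(meta, FastPrefillFields)


assert not maybe_fast_prefill(PlainMetadata(num_actual_tokens=8))
assert maybe_fast_prefill(FastPrefillMetadata(num_actual_tokens=8))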