diff --git a/tests/v1/e2e/test_kv_sharing_fast_prefill.py b/tests/v1/e2e/test_kv_sharing_fast_prefill.py
index f2c6d1c1fd1a..2778b0c5e567 100644
--- a/tests/v1/e2e/test_kv_sharing_fast_prefill.py
+++ b/tests/v1/e2e/test_kv_sharing_fast_prefill.py
@@ -4,13 +4,11 @@
 import random
 
 import pytest
-import torch
 
 from vllm import LLM, SamplingParams
 from vllm.config import CompilationConfig, CompilationMode
-from vllm.distributed import cleanup_dist_env_and_memory
 
-from ...utils import fork_new_process_for_each_test
+from ...utils import check_answers, fork_new_process_for_each_test, prep_prompts
 
 # global seed
 SEED = 42
@@ -45,28 +43,12 @@ def test_prompts():
     return prompts
 
 
-def cleanup(llm: LLM, compilation_config: CompilationConfig):
-    # hacky: below lines are required to free up memory for the next test
-    # when setting VLLM_ENABLE_V1_MULTIPROCESSING=0, del llm is not sufficient
-    # TODO(sarckk): when enforce_eager=False, memory is not freed:
-    # find out why and re-enable test for enforce_eager=False case
-    llm_engine = llm.llm_engine.engine_core.engine_core
-    model_runner = llm_engine.model_executor.driver_worker.worker.model_runner
-    del model_runner.model
-    del model_runner.kv_caches
-    del compilation_config.static_forward_context
-    compilation_config.static_forward_context = {}
-
-    del llm
-    torch.cuda.empty_cache()
-    cleanup_dist_env_and_memory()
-
-
 @fork_new_process_for_each_test
-@pytest.mark.parametrize("enforce_eager", [True])
-@pytest.mark.skip(reason="Disable until Gemma3n supports fast prefill")
+@pytest.mark.parametrize("kv_sharing_fast_prefill", [False, True])
+@pytest.mark.parametrize("enforce_eager", [True, False])
 def test_kv_sharing_fast_prefill(
     monkeypatch: pytest.MonkeyPatch,
+    kv_sharing_fast_prefill: bool,
     enforce_eager: bool,
     test_prompts: list[str],
 ):
@@ -79,36 +61,25 @@ def test_kv_sharing_fast_prefill(
         if not enforce_eager
         else CompilationMode.NONE,
     )
+    batch_size = 10
 
     with monkeypatch.context() as m:
         # Make scheduling deterministic for reproducibility
         m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
 
-        llm = LLM(
-            model="google/gemma-3n-E2B-it",
-            enforce_eager=enforce_eager,
-            compilation_config=compilation_config,
-            seed=SEED,
-        )
-        ref_responses = llm.generate(test_prompts, sampling_params)
-
-        cleanup(llm, compilation_config)
+        prompts, answer, indices = prep_prompts(batch_size)
 
         llm = LLM(
             model="google/gemma-3n-E2B-it",
             enforce_eager=enforce_eager,
             compilation_config=compilation_config,
             seed=SEED,
-            kv_sharing_fast_prefill=True,
+            kv_sharing_fast_prefill=kv_sharing_fast_prefill,
+        )
+        responses = llm.generate(prompts, sampling_params)
+        check_answers(
+            indices,
+            answer,
+            [response.outputs[0].text for response in responses],
+            accept_rate=1.0,
         )
-        optimized_responses = llm.generate(test_prompts, sampling_params)
-
-        cleanup(llm, compilation_config)
-
-        misses = 0
-
-        for ref_response, optimized_response in zip(ref_responses, optimized_responses):
-            if ref_response.outputs[0].text != optimized_response.outputs[0].text:
-                misses += 1
-
-        assert misses == 0
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index fd37a665cf05..578153cda786 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -965,12 +965,6 @@ def reshape_attn_output_for_spec_decode(attn_output: torch.Tensor) -> torch.Tens
     return attn_output.view(total_tokens, attn_output.shape[2], attn_output.shape[3])
 
 
-KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [
-    ("logits_indices_padded", torch.Tensor | None, None),
-    ("num_logits_indices", int, 0),
-]
-
-
 def subclass_attention_metadata(
     name_prefix: str,
     metadata_cls: Any,
@@ -986,8 +980,8 @@ def subclass_attention_metadata(
 
 @runtime_checkable
 class KVSharingFastPrefillMetadata(Protocol):
-    logits_indices_padded: torch.Tensor
-    num_logits_indices: int
+    logits_indices_padded: torch.Tensor | None = None
+    num_logits_indices: int | None = None
 
 
 def create_fast_prefill_custom_backend(
@@ -1019,11 +1013,6 @@ def create_fast_prefill_custom_backend(
                 for _field in fields(metadata.__class__):
                     setattr(self, _field.name, getattr(metadata, _field.name))
 
-                # Set additional fields that will be used in model code
-                assert (
-                    common_attn_metadata.logits_indices_padded is not None
-                    and common_attn_metadata.num_logits_indices is not None
-                )
                 self.logits_indices_padded = (
                     common_attn_metadata.logits_indices_padded
                 )
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index d0f7f3a501f5..341bf58f2da8 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1314,7 +1314,7 @@ class GPUModelRunner(
         :return: tuple[attn_metadata, spec_decode_common_attn_metadata]
         """
         logits_indices_padded = None
-        num_logits_indices = 0
+        num_logits_indices = None
         if logits_indices is not None:
             num_logits_indices = logits_indices.size(0)
         if self.cache_config.kv_sharing_fast_prefill: