[Bugfix] Fix block table for seqs that have prefix cache hits (#7018)

Zach Zheng 2024-08-02 22:38:15 -07:00 committed by GitHub
parent 0c25435daa
commit fb2c1c86c1
2 changed files with 65 additions and 3 deletions
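
Summary (a hedged reading of the hunks below, not text from the original commit message): FlashAttentionMetadataBuilder previously consulted each sequence group's own prefix_cache_hit flag inside _add_seq_group(); the commit title indicates that in a batch mixing cached and uncached prefills this could produce a wrong block table for the sequences with prefix cache hits. The fix derives a single batch-wide flag in build() with any() and passes it down. The stand-alone sketch below only illustrates that idea; SeqGroupData and batch_prefix_cache_hit are made-up placeholders, not the vLLM API.

    # Illustrative sketch only (simplified names, not the vLLM API): the flag is
    # now computed once over the whole batch instead of per sequence group.
    from dataclasses import dataclass
    from typing import List


    @dataclass
    class SeqGroupData:  # stand-in for ModelInputForGPUBuilder.InterDataForSeqGroup
        prefix_cache_hit: bool


    def batch_prefix_cache_hit(inter_data_list: List[SeqGroupData]) -> bool:
        # Mirrors the new logic in build(): true if any sequence group in the
        # batch hit the prefix cache.
        return any(d.prefix_cache_hit for d in inter_data_list)


    if __name__ == "__main__":
        mixed_batch = [SeqGroupData(False), SeqGroupData(True), SeqGroupData(False)]
        assert batch_prefix_cache_hit(mixed_batch)
        assert not batch_prefix_cache_hit([SeqGroupData(False), SeqGroupData(False)])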

@@ -6,10 +6,17 @@ from typing import List
 
 import pytest
 
+from tests.kernels.utils import override_backend_env_variable
 from vllm.block import PhysicalTokenBlock
 from vllm.core.block_manager_v1 import CachedBlockAllocator
 from vllm.utils import Device
 
+from ..models.utils import check_outputs_equal
+
+MODELS = [
+    "facebook/opt-125m",
+]
+
 
 @pytest.mark.parametrize("block_size", [16])
 @pytest.mark.parametrize("num_blocks", [16])
@@ -76,3 +83,52 @@ def test_eviction(num_blocks: int, ):
     assert (realloc_block != new_block)
     assert (new_block.block_hash == new_block_hash)
     assert (new_block.block_number == 2)
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"])
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [5])
+@pytest.mark.parametrize("cached_position", [0, 1])
+@pytest.mark.parametrize("use_v2_block_manager", [False, True])
+def test_mixed_requests(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    backend: str,
+    dtype: str,
+    max_tokens: int,
+    cached_position: int,
+    use_v2_block_manager: bool,
+    monkeypatch,
+) -> None:
+    """
+    Test the case when some sequences have the prefix cache hit
+    and the others don't. The cached position determines where
+    the cached sequence appears within the batch of prefills.
+    """
+    override_backend_env_variable(monkeypatch, backend)
+
+    with hf_runner(model, dtype=dtype) as hf_model:
+        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+
+    cached_prompt = example_prompts[cached_position]
+    with vllm_runner(
+            model,
+            dtype=dtype,
+            enable_prefix_caching=True,
+            use_v2_block_manager=use_v2_block_manager,
+    ) as vllm_model:
+        # Run the first prompt so the cache is populated
+        vllm_outputs = vllm_model.generate_greedy([cached_prompt], max_tokens)
+
+        # Run all the prompts
+        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
+
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
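
A possible way to run just this new case locally is sketched below; the test module path is an assumption based on the imports above (the diff itself does not name the changed files), so adjust it to wherever this test lives.

    # Hypothetical invocation sketch; the module path is assumed, not taken
    # from the diff. pytest.main() is used so the example stays in Python.
    import pytest

    pytest.main([
        "tests/prefix_caching/test_prefix_caching.py",  # assumed location
        "-k", "test_mixed_requests",
    ])

The second changed file, shown next, is the FlashAttention metadata builder exercised by the FLASH_ATTN backend parameterization above.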

@@ -209,6 +209,7 @@ class FlashAttentionMetadataBuilder(
         self.num_prefills = 0
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0
+        self.has_prefix_cache_hit = False
 
         self.input_builder = input_builder
         self.runner = input_builder.runner
@@ -219,7 +220,7 @@
 
     def _add_seq_group(
             self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
-            chunked_prefill_enabled: bool):
+            chunked_prefill_enabled: bool, prefix_cache_hit: bool):
         """Add a sequence group to the metadata. Specifically update/append
         1. context length.
         2. block table.
@@ -252,7 +253,7 @@
             # only allowing multiple of block_size chunk size.
             # NOTE: This only works for oooooooxxx style attention.
             block_table = []
-            if inter_data.prefix_cache_hit:
+            if prefix_cache_hit:
                 # NOTE(woosuk): For flash-attn, the block table should
                 # include the entries for the incoming prefill tokens.
                 block_table = block_tables[seq_id]
@@ -281,9 +282,14 @@
                 -1 if cuda graph is not used.
             batch_size: The maybe padded batch size.
         """
+        prefix_cache_hit = any([
+            inter_data.prefix_cache_hit
+            for inter_data in self.input_builder.inter_data_list
+        ])
         for inter_data in self.input_builder.inter_data_list:
             self._add_seq_group(inter_data,
-                                self.input_builder.chunked_prefill_enabled)
+                                self.input_builder.chunked_prefill_enabled,
+                                prefix_cache_hit)
 
         device = self.runner.device
         use_captured_graph = cuda_graph_pad_size != -1
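
To make the effect of the batch-wide flag concrete, here is a hedged toy model of the branch in _add_seq_group() (simplified names and a simplified fallback, not the real builder): with the new flag, every prefill in a mixed batch selects its block table the same way, whereas the old per-sequence check only took the prefix-cache branch for the sequences that themselves hit the cache. The empty-list fallback is a simplification; the real builder has further branches (e.g. for chunked prefill) outside this commit's context lines.

    # Toy model; names and the fallback branch are simplified assumptions,
    # not the vLLM implementation.
    from typing import Dict, List


    def toy_block_table(full_table: List[int], prefix_cache_hit: bool) -> List[int]:
        # On a prefix cache hit, flash-attn needs the full per-sequence block
        # table, including entries for the incoming prefill tokens.
        if prefix_cache_hit:
            return full_table
        return []  # simplification of the remaining (elided) branches


    block_tables: Dict[int, List[int]] = {0: [3, 7], 1: [9, 12]}
    per_seq_hit = {0: True, 1: False}  # mixed batch: only seq 0 hit the cache

    # Old behavior: each sequence group consulted only its own flag.
    old = {sid: toy_block_table(tbl, per_seq_hit[sid])
           for sid, tbl in block_tables.items()}

    # New behavior: one flag derived over the whole batch with any().
    batch_hit = any(per_seq_hit.values())
    new = {sid: toy_block_table(tbl, batch_hit)
           for sid, tbl in block_tables.items()}

    print(old)  # {0: [3, 7], 1: []}
    print(new)  # {0: [3, 7], 1: [9, 12]}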