[Core] Cross-attention KV caching and memory-management (towards eventual encoder/decoder model support) (#4837)

afeldman-nm 2024-05-29 12:09:13 -04:00 committed by GitHub
parent 594392d27a
commit 4238bc82f2
7 changed files with 731 additions and 65 deletions
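At a high level, this change gives each encoder/decoder SequenceGroup one cross-attention block table (keyed by request ID) alongside the existing per-sequence self-attention block tables. The sketch below is not part of the diff; it is a hedged usage example assembled from the test helpers added in this commit, and it assumes it runs from the vLLM source tree so the tests.core.utils helper is importable.

# Hedged usage sketch (mirrors the tests below; import path is an assumption).
from vllm.core.block_manager_v1 import BlockSpaceManagerV1
from tests.core.utils import create_dummy_prompt_encoder_decoder

block_size = 4
num_cpu_blocks = 4
num_gpu_blocks = 4
block_manager = BlockSpaceManagerV1(block_size,
                                    num_cpu_blocks,
                                    num_gpu_blocks,
                                    watermark=0)

# Helper added by this commit; returns (decoder_seq, encoder_seq, seq_group).
decoder_seq, encoder_seq, seq_group = create_dummy_prompt_encoder_decoder(
    "0",
    decoder_prompt_length=block_size,
    encoder_prompt_length=block_size)

block_manager.allocate(seq_group)
# Self-attention blocks are tracked per decoder sequence id ...
decoder_blocks = block_manager.get_block_table(decoder_seq)
# ... while cross-attention blocks are tracked once per request id.
cross_blocks = block_manager.get_cross_block_table(seq_group)

# Decoder and cross-attention tables are freed independently.
block_manager.free(decoder_seq)
block_manager.free_cross(seq_group)

free() drops only the decoder (self-attention) table, while free_cross() drops the group's cross-attention table, matching test_free_encoder_decoder below.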

View File

@ -1,11 +1,13 @@
import pytest
from vllm.core.block.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE,
STR_NOT_IMPL_ENC_DEC_SWA)
from vllm.core.block_manager_v2 import BlockSpaceManagerV2
from vllm.core.interfaces import AllocStatus
from vllm.sequence import Logprob, SequenceStatus
from vllm.utils import chunk_list
from ..utils import create_seq_group, create_seq_group_encoder_decoder
@pytest.mark.parametrize("block_size", [16])
@ -52,6 +54,156 @@ def test_can_allocate_seq_group(block_size: int, num_seqs_per_group: int,
assert can_allocate_result == AllocStatus.LATER
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_gpu_blocks", [16, 80, 160])
@pytest.mark.parametrize("num_seqs_per_group", [1, 4])
@pytest.mark.parametrize("watermark", [0.0, 0.5])
def test_can_allocate_seq_group_encoder_decoder(block_size: int,
num_seqs_per_group: int,
num_gpu_blocks: int,
watermark: float):
block_manager = BlockSpaceManagerV2(
block_size=block_size,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=1024,
watermark=watermark,
)
num_watermark_blocks = int(watermark * num_gpu_blocks)
num_output_blocks_per_seq = 1
# NOTE: This should be num_output_blocks_per_seq * num_seqs_per_group, but
# the current implementation assumes all seqs in the group are new prompts
# and do not have differing output lens.
num_output_blocks = num_output_blocks_per_seq
for bdx, num_prompt_blocks in enumerate(
range(1, num_gpu_blocks - num_output_blocks)):
num_cross_blocks_per_seq = num_prompt_blocks
seq_group = create_seq_group_encoder_decoder(
seq_prompt_len=block_size * num_prompt_blocks,
seq_output_lens=[
block_size * num_output_blocks_per_seq
for _ in range(num_seqs_per_group)
],
request_id=str(bdx))
assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks
can_allocate_result = block_manager.can_allocate(seq_group)
num_required_blocks = num_prompt_blocks + \
num_output_blocks + \
num_cross_blocks_per_seq
if num_gpu_blocks - num_required_blocks < num_watermark_blocks:
assert can_allocate_result == AllocStatus.NEVER
elif num_gpu_blocks >= num_required_blocks:
assert can_allocate_result == AllocStatus.OK
else:
assert can_allocate_result == AllocStatus.LATER
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_gpu_blocks", [16])
@pytest.mark.parametrize("num_seqs_per_group", [1])
@pytest.mark.parametrize("watermark", [0.0, 0.5])
def test_can_allocate_encoder_decoder_fails_with_swa(block_size: int,
num_seqs_per_group: int,
num_gpu_blocks: int,
watermark: float):
'''
SWA is short for Sliding Window Attention.
At the time of writing, block manager v2 does not support SWA.
However, even once SWA is implemented for block manager v2,
a separate workstream will most likely still be required
to enable SWA for encoder/decoder models.
Therefore this test enforces that one of the following cases
holds true:
1. Block manager v2 does not support SWA at all (true at time of writing), or
2. Block manager v2 fails with NotImplementedError when SWA is enabled
AND a SequenceGroup with an encoder sequence (i.e. in support of an
encoder/decoder model) is passed into can_allocate() as an argument.
The setup for this test is a stripped-down version of
test_can_allocate_seq_group_encoder_decoder().
'''
with pytest.raises((NotImplementedError, AssertionError)) as exc_info:
block_manager = BlockSpaceManagerV2(
block_size=block_size,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=1024,
watermark=watermark,
sliding_window=5 # SWA
)
num_output_blocks_per_seq = 1
num_prompt_blocks = 1
num_output_blocks = num_output_blocks_per_seq
seq_group = create_seq_group_encoder_decoder(
seq_prompt_len=block_size * num_prompt_blocks,
seq_output_lens=[
block_size * num_output_blocks_per_seq
for _ in range(num_seqs_per_group)
],
request_id="0")
assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks
block_manager.can_allocate(seq_group)
# Assert that either
# 1. Block manager v2 constructor fails with assertion that sliding window
# is not yet supported (most likely near-term outcome at time of
# writing), or
# 2. can_allocate() fails with NotImplementedError due to combination of
# encoder/decoder and sliding window attention
if isinstance(exc_info.value, NotImplementedError):
assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_SWA
elif isinstance(exc_info.value, AssertionError):
assert str(exc_info.value) == "Sliding window not yet supported"
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_gpu_blocks", [16])
@pytest.mark.parametrize("num_seqs_per_group", [1])
@pytest.mark.parametrize("watermark", [0.0, 0.5])
def test_can_allocate_encoder_decoder_fails_with_prefix_cache(
block_size: int, num_seqs_per_group: int, num_gpu_blocks: int,
watermark: float):
block_manager = BlockSpaceManagerV2(
block_size=block_size,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=1024,
watermark=watermark,
enable_caching=True # Prefix cache
)
num_output_blocks_per_seq = 1
num_prompt_blocks = 1
num_output_blocks = num_output_blocks_per_seq
seq_group = create_seq_group_encoder_decoder(
seq_prompt_len=block_size * num_prompt_blocks,
seq_output_lens=[
block_size * num_output_blocks_per_seq
for _ in range(num_seqs_per_group)
],
request_id="0")
assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks
# Assert that can_allocate() fails with NotImplementedError due to the
# combination of encoder/decoder and prefix caching
with pytest.raises(NotImplementedError) as exc_info:
block_manager.can_allocate(seq_group)
assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE
@pytest.mark.parametrize("block_size", [1, 8])
@pytest.mark.parametrize("prompt_len", [1, 7, 8])
@pytest.mark.parametrize("num_slots_to_append", [1, 8, 129])

View File

@ -6,13 +6,15 @@ import pytest
from vllm import SamplingParams
from vllm.block import PhysicalTokenBlock
from vllm.core.block.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE,
STR_NOT_IMPL_ENC_DEC_SWA)
from vllm.core.block_manager_v1 import (BlockSpaceManagerV1,
UncachedBlockAllocator)
from vllm.core.interfaces import AllocStatus
from vllm.sequence import Logprob, Sequence, SequenceGroup, SequenceStatus
from vllm.utils import Device
from .utils import create_dummy_prompt, create_dummy_prompt_encoder_decoder
def test_block_allocator_allocate():
@ -73,7 +75,7 @@ def test_allocate():
# Allocate same sequence group to all available gpu blocks.
for i in range(num_gpu_blocks):
_, seq_group = create_dummy_prompt(str(i), block_size)
assert block_manager.can_allocate(seq_group) == AllocStatus.OK
block_manager.allocate(seq_group)
assert block_manager.can_allocate(seq_group) != AllocStatus.OK
@ -85,11 +87,107 @@ def test_allocate():
watermark=1 / num_gpu_blocks)
for i in range(num_gpu_blocks - 1):
_, seq_group = create_dummy_prompt(str(i), block_size)
assert block_manager.can_allocate(seq_group) == AllocStatus.OK
block_manager.allocate(seq_group)
assert block_manager.can_allocate(seq_group) != AllocStatus.OK
def test_allocate_encoder_decoder():
block_size = 4
num_cpu_blocks = 4
num_gpu_blocks = 4
block_req_per_seq_group = 2
block_manager = BlockSpaceManagerV1(block_size,
num_cpu_blocks,
num_gpu_blocks,
watermark=0)
# Allocate same sequence group to all available gpu blocks.
for i in range(num_gpu_blocks // block_req_per_seq_group):
_, _, seq_group = create_dummy_prompt_encoder_decoder(
str(i),
decoder_prompt_length=block_size,
encoder_prompt_length=block_size)
assert block_manager.can_allocate(seq_group) == AllocStatus.OK
block_manager.allocate(seq_group)
assert block_manager.can_allocate(seq_group) != AllocStatus.OK
# Allocate same sequence group to all available gpu blocks.
# Use watermark to reserve one gpu block.
block_manager = BlockSpaceManagerV1(block_size,
num_cpu_blocks,
num_gpu_blocks,
watermark=1 / num_gpu_blocks)
for i in range((num_gpu_blocks - 1) // block_req_per_seq_group):
_, _, seq_group = create_dummy_prompt_encoder_decoder(
str(i),
decoder_prompt_length=block_size,
encoder_prompt_length=block_size)
assert block_manager.can_allocate(seq_group) == AllocStatus.OK
block_manager.allocate(seq_group)
assert block_manager.can_allocate(seq_group) != AllocStatus.OK
def test_allocate_encoder_decoder_fails_with_swa():
# SWA short for sliding window attention
block_size = 4
num_cpu_blocks = 4
num_gpu_blocks = 4
block_manager = BlockSpaceManagerV1(block_size,
num_cpu_blocks,
num_gpu_blocks,
watermark=0,
sliding_window=5) # swa
# Allocate same sequence group to all available gpu blocks.
_, _, seq_group = create_dummy_prompt_encoder_decoder(
"0",
decoder_prompt_length=block_size,
encoder_prompt_length=block_size)
# Assert that can_allocate() fails due to SWA
with pytest.raises(NotImplementedError) as exc_info:
block_manager.can_allocate(seq_group)
assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_SWA
# Assert that allocate() fails due to SWA
with pytest.raises(NotImplementedError) as exc_info:
block_manager.allocate(seq_group)
assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_SWA
def test_allocate_encoder_decoder_fails_with_prefix_caching():
block_size = 4
num_cpu_blocks = 4
num_gpu_blocks = 4
block_manager = BlockSpaceManagerV1(block_size,
num_cpu_blocks,
num_gpu_blocks,
watermark=0,
enable_caching=True) # Prefix cache
# Allocate same sequence group to all available gpu blocks.
_, _, seq_group = create_dummy_prompt_encoder_decoder(
"0",
decoder_prompt_length=block_size,
encoder_prompt_length=block_size)
# Assert that can_allocate() fails due to prefix caching
with pytest.raises(NotImplementedError) as exc_info:
block_manager.can_allocate(seq_group)
assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE
# Assert that allocate() fails due to prefix caching
with pytest.raises(NotImplementedError) as exc_info:
block_manager.allocate(seq_group)
assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE
def test_append_slot_single_seq():
block_size = 4
num_cpu_blocks = 4
@ -244,6 +342,62 @@ def test_swap():
assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks)
def test_swap_encoder_decoder():
block_size = 4
num_cpu_blocks = 4
num_gpu_blocks = 4
block_manager = BlockSpaceManagerV1(block_size,
num_cpu_blocks,
num_gpu_blocks,
watermark=0)
decoder_prompt, encoder_prompt, seq_group = \
create_dummy_prompt_encoder_decoder(
"1",
decoder_prompt_length=block_size,
encoder_prompt_length=block_size)
decoder_prompt.status = SequenceStatus.WAITING
encoder_prompt.status = SequenceStatus.WAITING
block_manager.allocate(seq_group)
# Emulate a forward pass by appending a single token.
# The block manager then knows how many unprocessed
# tokens will be written in the next forward pass.
token_id = 0
decoder_prompt.status = SequenceStatus.RUNNING
decoder_prompt.append_token_id(token_id, {token_id: Logprob(0.0)})
# Swap encoder/decoder seq group from GPU -> CPU.
decoder_gpu_blocks = block_manager.get_block_table(decoder_prompt)
cross_gpu_blocks = block_manager.get_cross_block_table(seq_group)
gpu_blocks = decoder_gpu_blocks + cross_gpu_blocks
assert block_manager.can_swap_out(seq_group)
before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
mapping = block_manager.swap_out(seq_group)
assert [x[0] for x in mapping] == gpu_blocks
#assert list(mapping.keys()) == gpu_blocks
after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks)
assert before_gpu_blocks + len(gpu_blocks) == after_gpu_blocks
decoder_prompt.status = SequenceStatus.SWAPPED
# Swap encoder/decoder seq group from CPU -> GPU.
decoder_cpu_blocks = block_manager.get_block_table(decoder_prompt)
cross_cpu_blocks = block_manager.get_cross_block_table(seq_group)
cpu_blocks = decoder_cpu_blocks + cross_cpu_blocks
assert block_manager.can_swap_in(seq_group) == AllocStatus.OK
before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
mapping = block_manager.swap_in(seq_group)
assert [x[0] for x in mapping] == cpu_blocks
after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks
assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks)
def test_free():
block_size = 4
num_cpu_blocks = 4
@ -268,6 +422,41 @@ def test_free():
block_manager.get_block_table(prompt)
def test_free_encoder_decoder():
block_size = 4
num_cpu_blocks = 4
num_gpu_blocks = 4
block_manager = BlockSpaceManagerV1(block_size,
num_cpu_blocks,
num_gpu_blocks,
watermark=0)
decoder_prompt, encoder_prompt, seq_group = \
create_dummy_prompt_encoder_decoder(
"1",
decoder_prompt_length=block_size,
encoder_prompt_length=block_size)
block_manager.allocate(seq_group)
# Free allocated seq.
decoder_prompt_blocks = len(block_manager.get_block_table(decoder_prompt))
encoder_prompt_blocks = len(block_manager.get_cross_block_table(seq_group))
prompt_blocks = decoder_prompt_blocks + encoder_prompt_blocks
before_blocks = block_manager.get_num_free_gpu_blocks()
block_manager.free(decoder_prompt)
block_manager.free_cross(seq_group)
after_blocks = block_manager.get_num_free_gpu_blocks()
assert after_blocks == before_blocks + prompt_blocks
# Block table for the freed decoder seq is deleted.
with pytest.raises(KeyError):
block_manager.get_block_table(decoder_prompt)
# Block table for the freed encoder seq is deleted.
with pytest.raises(KeyError):
block_manager.get_block_table(encoder_prompt)
def test_reset():
block_size = 4
num_cpu_blocks = 4
@ -289,6 +478,31 @@ def test_reset():
assert block_manager.get_num_free_gpu_blocks() == original_blocks
def test_reset_encoder_decoder():
block_size = 4
num_cpu_blocks = 4
num_gpu_blocks = 4
block_req_per_seq_group = 2
block_manager = BlockSpaceManagerV1(block_size,
num_cpu_blocks,
num_gpu_blocks,
watermark=0)
# Allocate same seq group on all available gpu blocks.
original_blocks = block_manager.get_num_free_gpu_blocks()
for i in range(num_gpu_blocks // block_req_per_seq_group):
_, _, seq_group = create_dummy_prompt_encoder_decoder(
f"{i}",
decoder_prompt_length=block_size,
encoder_prompt_length=block_size)
block_manager.allocate(seq_group)
assert block_manager.get_num_free_gpu_blocks() == 0
# Resetting block manager frees all allocated blocks.
block_manager.reset()
assert block_manager.get_num_free_gpu_blocks() == original_blocks
def test_sliding_window_multi_seq():
"""
Tests that memory allocation and deallocation is handled

View File

@ -39,6 +39,52 @@ def create_dummy_prompt(
return prompt, seq_group
def create_dummy_prompt_encoder_decoder(
request_id: str,
decoder_prompt_length: int,
encoder_prompt_length: int,
block_size: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
use_beam_search: bool = False,
best_of: int = 1,
) -> Tuple[Sequence, Sequence, SequenceGroup]:
if not block_size:
block_size = decoder_prompt_length
# Create a dummy decoder prompt sequence with tokens
# 0...decoder_prompt_length-1 and the matching space-separated prompt string.
decoder_prompt_tokens = list(range(decoder_prompt_length))
decoder_prompt_str = " ".join([str(t) for t in decoder_prompt_tokens])
decoder_prompt = Sequence(int(request_id),
inputs={
"prompt": decoder_prompt_str,
"prompt_token_ids": decoder_prompt_tokens,
"multi_modal_data": None,
},
block_size=block_size)
encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length))))
encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens])
encoder_prompt = Sequence(int(request_id),
inputs={
"prompt": encoder_prompt_str,
"prompt_token_ids": encoder_prompt_tokens,
"multi_modal_data": None,
},
block_size=block_size)
seq_group = SequenceGroup(request_id=request_id,
seqs=[decoder_prompt],
sampling_params=SamplingParams(
use_beam_search=use_beam_search,
best_of=best_of),
arrival_time=time.time(),
lora_request=lora_request,
encoder_seq=encoder_prompt)
return decoder_prompt, encoder_prompt, seq_group
def create_seq_group(
seq_prompt_len: int = 1024,
seq_output_lens: Iterable[int] = (128, ),
@ -82,5 +128,56 @@ def create_seq_group(
return seq_group
def create_seq_group_encoder_decoder(
seq_prompt_len: int = 1024,
seq_output_lens: Iterable[int] = (128, ),
request_id: str = '0',
seq_id_start: int = 0,
sampling_params: Optional[SamplingParams] = None) -> SequenceGroup:
assert len(seq_output_lens) > 0
if sampling_params is None:
sampling_params = SamplingParams()
prompt_token_ids = [0] * seq_prompt_len
seqs = []
for seq_id_offset, output_len in enumerate(seq_output_lens):
seq = Sequence(
seq_id=seq_id_start + seq_id_offset,
inputs={
"prompt": "",
"prompt_token_ids": prompt_token_ids,
"multi_modal_data": None,
},
block_size=16,
)
for i in range(output_len):
seq.append_token_id(
token_id=i,
logprobs={i: Logprob(0.0)},
)
seqs.append(seq)
# Encoder sequence
encoder_seq = Sequence(
seq_id=seq_id_start + len(seq_output_lens),
inputs={
"prompt": "",
"prompt_token_ids": prompt_token_ids,
"multi_modal_data": None,
},
block_size=16,
)
return SequenceGroup(request_id=request_id,
seqs=seqs,
sampling_params=sampling_params,
arrival_time=time.time(),
encoder_seq=encoder_seq)
def round_up_to_next_block(seq_len: int, block_size: int) -> int:
return (seq_len + block_size - 1) // block_size

vllm/core/block/utils.py (new file, 56 lines added)
View File

@ -0,0 +1,56 @@
"""Block manager utils."""
from vllm.sequence import SequenceGroup
# Exception strings for non-implemented block manager enc/dec scenarios
STR_NOT_IMPL_ENC_DEC_SWA = \
"Sliding window attention for encoder/decoder models " + \
"is not currently supported."
STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \
"Prefix caching for encoder/decoder models " + \
"is not currently supported."
def _get_block_mgr_sliding_window_attr(block_mgr):
'''
BlockManagerV1 and BlockManagerV2 have slightly different
members related to sliding window attention (SWA). This
function extracts the appropriate member to use for determining
whether SWA is enabled.
Arguments:
* block_mgr: BlockManagerV1 or BlockManagerV2 instance
'''
if hasattr(block_mgr, 'block_sliding_window'):
return block_mgr.block_sliding_window
if hasattr(block_mgr, 'max_block_sliding_window'):
return block_mgr.max_block_sliding_window
raise AttributeError("Block manager instance has neither " + \
"block_sliding_window nor " + \
"max_block_sliding_window attributes.")
def check_no_caching_or_swa_for_blockmgr_encdec(
block_mgr, seq_group: SequenceGroup) -> None:
'''
Enforce that prefix caching & sliding-window attention (SWA)
are currently unsupported *specifically* for encoder/decoder models.
Raises NotImplementedError if an unsupported scenario is detected.
Arguments:
* block_mgr: BlockSpaceManager instance
* seq_group: SequenceGroup passed to block_mgr
'''
if seq_group.is_encoder_decoder():
if _get_block_mgr_sliding_window_attr(block_mgr) is not None:
raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA)
if block_mgr.enable_caching:
raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE)
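This check is meant to be called at the top of block-manager entry points, as the can_allocate()/allocate() hunks below do. A minimal, hedged sketch of the failure path it enforces, mirroring test_allocate_encoder_decoder_fails_with_swa above (the tests.core.utils import path is an assumption about the test layout):

# Hedged sketch of the guard firing for an encoder/decoder group with SWA.
import pytest
from vllm.core.block.utils import (STR_NOT_IMPL_ENC_DEC_SWA,
                                   check_no_caching_or_swa_for_blockmgr_encdec)
from vllm.core.block_manager_v1 import BlockSpaceManagerV1
from tests.core.utils import create_dummy_prompt_encoder_decoder

block_manager = BlockSpaceManagerV1(4, 4, 4, watermark=0, sliding_window=5)
_, _, seq_group = create_dummy_prompt_encoder_decoder(
    "0", decoder_prompt_length=4, encoder_prompt_length=4)

# The guard raises only for encoder/decoder groups; decoder-only groups pass.
with pytest.raises(NotImplementedError) as exc_info:
    check_no_caching_or_swa_for_blockmgr_encdec(block_manager, seq_group)
assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_SWA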

View File

@ -8,6 +8,7 @@ from typing import Sequence as GenericSequence
from typing import Set, Tuple
from vllm.block import BlockTable, PhysicalTokenBlock
from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec
from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.logger import init_logger
@ -255,14 +256,30 @@ class BlockSpaceManagerV1(BlockSpaceManager):
Device.CPU, block_size, num_cpu_blocks)
# Mapping: seq_id -> BlockTable.
self.block_tables: Dict[int, BlockTable] = {}
# Mapping: req_id -> BlockTable
# Note that each SequenceGroup has a unique
# request ID
self.cross_block_tables: Dict[str, BlockTable] = {}
def _get_seq_num_required_blocks(self, seq: Sequence) -> int:
return 0 if seq is None \
else len(seq.logical_token_blocks)
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
# FIXME(woosuk): Here we assume that all sequences in the group share
# the same prompt. This may not be true for preempted sequences.
check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)
self_num_required_blocks = self._get_seq_num_required_blocks(
seq_group.get_seqs(status=SequenceStatus.WAITING)[0])
cross_num_required_blocks = self._get_seq_num_required_blocks(
seq_group.get_encoder_seq())
num_required_blocks = self_num_required_blocks + \
cross_num_required_blocks
if self.block_sliding_window is not None:
num_required_blocks = min(num_required_blocks,
self.block_sliding_window)
num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
@ -276,11 +293,10 @@ class BlockSpaceManagerV1(BlockSpaceManager):
else:
return AllocStatus.LATER
def _allocate_sequence(self, \
seq: Sequence, \
ref_count: int, \
is_encoder_decoder: bool = True) -> BlockTable:
# Allocate new physical token blocks that will store the prompt tokens.
num_prompt_blocks = len(seq.logical_token_blocks)
@ -290,21 +306,46 @@ class BlockSpaceManagerV1(BlockSpaceManager):
and logical_idx >= self.block_sliding_window):
block = block_table[logical_idx % self.block_sliding_window]
# Set the reference counts of the token blocks.
block.ref_count = ref_count
elif not is_encoder_decoder and self.enable_caching:
block = self.gpu_allocator.allocate(
seq.hash_of_block(logical_idx),
seq.num_hashed_tokens_of_block(logical_idx))
else:
block = self.gpu_allocator.allocate()
# Set the reference counts of the token blocks.
block.ref_count = ref_count
block_table.append(block)
return block_table
def allocate(self, seq_group: SequenceGroup) -> None:
is_encoder_decoder = seq_group.is_encoder_decoder()
check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)
# Allocate decoder sequences
#
# NOTE: Here we assume that all sequences in the group have the same
# decoder prompt.
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
block_table: BlockTable = \
self._allocate_sequence(seq,
seq_group.num_seqs(),
is_encoder_decoder)
# Assign the self-attention block tables for each sequence.
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
self.block_tables[seq.seq_id] = block_table.copy()
# Allocate encoder sequence
if is_encoder_decoder:
# A SequenceGroup has only a single encoder sequence (at most),
# thus allocate with a ref count of 1
block_table = self._allocate_sequence(seq_group.get_encoder_seq(),
1, is_encoder_decoder)
# Assign the cross-attention block table for the SequenceGroup.
self.cross_block_tables[seq_group.request_id] = block_table
def can_append_slots(self,
seq_group: SequenceGroup,
num_lookahead_slots: int = 0) -> bool:
@ -443,13 +484,18 @@ class BlockSpaceManagerV1(BlockSpaceManager):
def _get_physical_blocks(
self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]:
# NOTE: Here, we assume that the physical blocks are only shared by
# the sequences in the same group.
request_id = seq_group.request_id
blocks: Set[PhysicalTokenBlock] = set()
for seq in seq_group.get_seqs():
if seq.is_finished():
continue
blocks.update(self.block_tables[seq.seq_id])
# Cross-attention blocks
if seq_group.is_encoder_decoder():
blocks.update(self.cross_block_tables[request_id])
return list(blocks)
def can_swap_in(self,
@ -457,8 +503,11 @@ class BlockSpaceManagerV1(BlockSpaceManager):
num_lookahead_slots: int = 0) -> AllocStatus:
assert (num_lookahead_slots == 0
), "BlockSpaceManagerV1 does not support lookahead allocation"
blocks = self._get_physical_blocks(seq_group)
num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
if seq_group.is_encoder_decoder():
num_swapped_seqs += 1
num_free_blocks = self.gpu_allocator.get_num_free_blocks()
# NOTE: Conservatively, we assume that every sequence will allocate
# at least one free block right after the swap-in.
@ -471,70 +520,81 @@ class BlockSpaceManagerV1(BlockSpaceManager):
else:
return AllocStatus.LATER
def _swap_block_table(
self, block_table: BlockTable, src_allocator: BlockAllocatorBase,
dest_allocator: BlockAllocatorBase,
mapping: Dict[PhysicalTokenBlock,
PhysicalTokenBlock]) -> BlockTable:
new_block_table = []
for from_block in block_table:
if from_block in mapping:
to_block = mapping[from_block]
to_block.ref_count += 1
else:
to_block = dest_allocator.allocate(
from_block.block_hash, from_block.num_hashed_tokens)
mapping[from_block] = to_block
new_block_table.append(to_block)
# Free the source block swapped in to destination.
src_allocator.free(from_block)
return new_block_table
def swap_in(self,
seq_group: SequenceGroup,
num_lookahead_slots: int = 0) -> List[Tuple[int, int]]:
assert (num_lookahead_slots == 0
), "BlockSpaceManagerV1 does not support lookahead allocation"
request_id = seq_group.request_id
# CPU block -> GPU block.
# dict is efficient in lookup `if cpu_block in mapping`
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
self.block_tables[seq.seq_id] = \
self._swap_block_table(self.block_tables[seq.seq_id],
self.cpu_allocator,
self.gpu_allocator,
mapping)
if seq_group.is_encoder_decoder():
self.cross_block_tables[request_id] = \
self._swap_block_table(self.cross_block_tables[request_id],
self.cpu_allocator,
self.gpu_allocator,
mapping)
return [(cpu_block.block_number, gpu_block.block_number)
for cpu_block, gpu_block in mapping.items()]
def can_swap_out(self, seq_group: SequenceGroup) -> bool:
blocks = self._get_physical_blocks(seq_group)
return len(blocks) <= self.cpu_allocator.get_num_free_blocks()
def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
request_id = seq_group.request_id
# GPU block -> CPU block.
# dict is efficient in lookup `if gpu_block in mapping`
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
self.block_tables[seq.seq_id] = \
self._swap_block_table(self.block_tables[seq.seq_id],
self.gpu_allocator,
self.cpu_allocator,
mapping)
if seq_group.is_encoder_decoder():
self.cross_block_tables[request_id] = \
self._swap_block_table(self.cross_block_tables[request_id],
self.gpu_allocator,
self.cpu_allocator,
mapping)
return [(cpu_block.block_number, gpu_block.block_number)
for cpu_block, gpu_block in mapping.items()]
def _free_block_table(self, block_table: BlockTable) -> None:
# when using a sliding window, each seq will only use up
@ -559,15 +619,32 @@ class BlockSpaceManagerV1(BlockSpaceManager):
self._free_block_table(block_table)
del self.block_tables[seq.seq_id]
def free_cross(self, seq_group: SequenceGroup) -> None:
if seq_group.request_id not in self.cross_block_tables:
# Already freed or hasn't been scheduled yet.
return
block_table = self.cross_block_tables[seq_group.request_id]
self._free_block_table(block_table)
del self.cross_block_tables[seq_group.request_id]
def reset(self) -> None:
# Free decoder block tables
for block_table in self.block_tables.values():
self._free_block_table(block_table)
self.block_tables.clear()
# Free cross-attention block tables
for block_table in self.cross_block_tables.values():
self._free_block_table(block_table)
self.cross_block_tables.clear()
def get_block_table(self, seq: Sequence) -> List[int]:
block_table = self.block_tables[seq.seq_id]
return [block.block_number for block in block_table]
def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]:
block_table = self.cross_block_tables[seq_group.request_id]
return [block.block_number for block in block_table]
def get_num_free_gpu_blocks(self) -> int:
return self.gpu_allocator.get_num_free_blocks()
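The refactored swap paths above route both the per-sequence decoder tables and the group's cross-attention table through _swap_block_table(), and return the mapping as a list of (source, destination) block-number tuples rather than a dict. Below is a hedged sketch of the caller-visible behavior, following test_swap_encoder_decoder above (same test-helper import assumption as before):

# Hedged sketch of swapping an encoder/decoder group out of GPU memory.
from vllm.core.block_manager_v1 import BlockSpaceManagerV1
from vllm.sequence import Logprob, SequenceStatus
from tests.core.utils import create_dummy_prompt_encoder_decoder

block_manager = BlockSpaceManagerV1(4, 4, 4, watermark=0)
decoder_seq, _, seq_group = create_dummy_prompt_encoder_decoder(
    "1", decoder_prompt_length=4, encoder_prompt_length=4)
block_manager.allocate(seq_group)

# Emulate one decode step so the decoder sequence is RUNNING with a new token.
decoder_seq.status = SequenceStatus.RUNNING
decoder_seq.append_token_id(0, {0: Logprob(0.0)})

# Both decoder and cross-attention blocks are swapped together.
gpu_blocks = (block_manager.get_block_table(decoder_seq) +
              block_manager.get_cross_block_table(seq_group))
assert block_manager.can_swap_out(seq_group)
mapping = block_manager.swap_out(seq_group)

# Each entry pairs a source GPU block number with its destination CPU block.
assert [src for src, _dst in mapping] == gpu_blocks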

View File

@ -5,11 +5,13 @@ from typing import Tuple
from vllm.core.block.block_table import BlockTable
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
from vllm.utils import Device
SeqId = int
EncoderSeqId = str
class BlockSpaceManagerV2(BlockSpaceManager):
@ -94,17 +96,26 @@ class BlockSpaceManagerV2(BlockSpaceManager):
)
self.block_tables: Dict[SeqId, BlockTable] = {}
self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {}
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
# FIXME(woosuk): Here we assume that all sequences in the group share
# the same prompt. This may not be true for preempted sequences.
check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
num_required_blocks = BlockTable.get_num_required_blocks(
seq.get_token_ids(),
block_size=self.block_size,
)
if seq_group.is_encoder_decoder():
num_required_blocks += BlockTable.get_num_required_blocks(
seq_group.get_encoder_seq().get_token_ids(),
block_size=self.block_size,
)
if self.max_block_sliding_window is not None:
num_required_blocks = min(num_required_blocks,
self.max_block_sliding_window)
@ -121,7 +132,19 @@ class BlockSpaceManagerV2(BlockSpaceManager):
else:
return AllocStatus.LATER
def _allocate_sequence(self, seq: Sequence) -> BlockTable:
block_table = BlockTable(
block_size=self.block_size,
block_allocator=self.block_allocator,
max_block_sliding_window=self.max_block_sliding_window,
)
block_table.allocate(seq.get_token_ids())
return block_table
def allocate(self, seq_group: SequenceGroup) -> None:
# Allocate self-attention block tables for decoder sequences
waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
assert not (set(seq.seq_id for seq in waiting_seqs)
& self.block_tables.keys()), "block table already exists"
@ -129,20 +152,29 @@ class BlockSpaceManagerV2(BlockSpaceManager):
# NOTE: Here we assume that all sequences in the group have the same
# prompt.
seq = waiting_seqs[0]
block_table: BlockTable = self._allocate_sequence(seq)
self.block_tables[seq.seq_id] = block_table
# Assign the block table for each sequence.
for seq in waiting_seqs[1:]:
self.block_tables[seq.seq_id] = block_table.fork()
# Allocate cross-attention block table for encoder sequence
#
# NOTE: Here we assume that all sequences in the group have the same
# encoder prompt.
request_id = seq_group.request_id
assert (request_id
not in self.cross_block_tables), \
"block table already exists"
check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)
if seq_group.is_encoder_decoder():
block_table = self._allocate_sequence(seq_group.get_encoder_seq())
self.cross_block_tables[request_id] = block_table
def can_append_slots(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> bool:
"""Determine if there is enough space in the GPU KV cache to continue
@ -197,12 +229,27 @@ class BlockSpaceManagerV2(BlockSpaceManager):
self.block_tables[seq.seq_id].free()
del self.block_tables[seq.seq_id]
def free_cross(self, seq_group: SequenceGroup) -> None:
request_id = seq_group.request_id
if request_id not in self.cross_block_tables:
# Already freed or hasn't been scheduled yet.
return
self.cross_block_tables[request_id].free()
del self.cross_block_tables[request_id]
def get_block_table(self, seq: Sequence) -> List[int]:
assert seq.seq_id in self.block_tables
block_ids = self.block_tables[seq.seq_id].physical_block_ids
assert all(b is not None for b in block_ids)
return block_ids  # type: ignore
def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]:
request_id = seq_group.request_id
assert request_id in self.cross_block_tables
block_ids = self.cross_block_tables[request_id].physical_block_ids
assert all(b is not None for b in block_ids)
return block_ids # type: ignore
def access_all_blocks_in_seq(self, seq: Sequence, now: float):
# Update the last accessed time of all the blocks accessed
# in this step.

View File

@ -430,6 +430,8 @@ class SequenceGroup:
for an embedding model.
pooling_params: The pooling parameters used to generate the pooling
for an embedding model.
encoder_seq: Optional, the single encoder sequence. Should be None
unless you are working with an encoder/decoder model.
"""
def __init__(
@ -441,6 +443,7 @@ class SequenceGroup:
lora_request: Optional[LoRARequest] = None,
embeddings: Optional[List[float]] = None,
pooling_params: Optional[PoolingParams] = None,
encoder_seq: Optional[Sequence] = None,
) -> None:
self.request_id = request_id
self.seqs_dict = {seq.seq_id: seq for seq in seqs}
@ -455,6 +458,7 @@ class SequenceGroup:
self.state = SequenceGroupState()
self.embeddings = embeddings
self.pooling_params = pooling_params
self.encoder_seq = encoder_seq
@property
def prompt(self) -> Optional[str]:
@ -538,6 +542,12 @@ class SequenceGroup:
seq for seq in self.seqs_dict.values() if seq.status == status
]
def is_encoder_decoder(self) -> bool:
return self.encoder_seq is not None
def get_encoder_seq(self) -> Optional[Sequence]:
return self.encoder_seq
def get_unfinished_seqs(self) -> List[Sequence]:
return [
seq for seq in self.seqs_dict.values() if not seq.is_finished()
@ -621,6 +631,15 @@ class SequenceGroupMetadata:
used in prefix caching.
state: Internal state tied to this sequence group.
multi_modal_data: Multi modal data.
encoder_seq_data: Optional sequence data for encoder prompt
(SequenceGroup.encoder_seq). Should be None
unless you are working with an encoder/decoder
model.
cross_block_table: Optional cross-attention block table associated
with the encoder prompt
(SequenceGroup.encoder_seq). Should be None
unless you are working with an encoder/decoder
model.
""" """
def __init__( def __init__(
@ -637,6 +656,8 @@ class SequenceGroupMetadata:
computed_block_nums: Optional[List[int]] = None,
state: Optional[SequenceGroupState] = None,
multi_modal_data: Optional[MultiModalData] = None,
encoder_seq_data: Optional[SequenceData] = None,
cross_block_table: Optional[List[int]] = None,
) -> None:
self.request_id = request_id
self.is_prompt = is_prompt
@ -648,6 +669,8 @@ class SequenceGroupMetadata:
self.computed_block_nums = computed_block_nums
self.multi_modal_data = multi_modal_data
self.state = SequenceGroupState() if state is None else state
self.encoder_seq_data = encoder_seq_data
self.cross_block_table = cross_block_table
self._token_chunk_size = token_chunk_size
self.do_sample = do_sample
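For reference, a hedged sketch of constructing an encoder/decoder SequenceGroup with the new encoder_seq field and querying it through is_encoder_decoder()/get_encoder_seq(); the inputs dict layout mirrors the test helpers added above and is an assumption outside of those helpers:

# Hedged sketch based on the test helpers above (not part of the diff).
import time
from vllm import SamplingParams
from vllm.sequence import Sequence, SequenceGroup

decoder_seq = Sequence(seq_id=0,
                       inputs={
                           "prompt": "",
                           "prompt_token_ids": [0] * 16,
                           "multi_modal_data": None,
                       },
                       block_size=16)
encoder_seq = Sequence(seq_id=1,
                       inputs={
                           "prompt": "",
                           "prompt_token_ids": [0] * 16,
                           "multi_modal_data": None,
                       },
                       block_size=16)

# Passing encoder_seq marks the group as encoder/decoder.
seq_group = SequenceGroup(request_id="0",
                          seqs=[decoder_seq],
                          sampling_params=SamplingParams(),
                          arrival_time=time.time(),
                          encoder_seq=encoder_seq)

assert seq_group.is_encoder_decoder()
assert seq_group.get_encoder_seq() is encoder_seq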