Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 23:35:52 +08:00)
[Core] Cross-attention KV caching and memory-management (towards eventual encoder/decoder model support) (#4837)
commit 4238bc82f2
parent 594392d27a
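At a high level, the diff below gives both block managers (V1 and V2) a second, per-request block table for the encoder prompt's cross-attention KV cache, kept alongside the existing per-sequence decoder tables and folded into allocation, swap, free, and reset accounting. The snippet that follows is only a minimal sketch of that bookkeeping idea, assuming a toy manager; ToyBlockManager and its methods are illustrative stand-ins, not the vLLM API.

# Illustrative sketch only: a toy manager that mirrors the bookkeeping this
# commit adds to BlockSpaceManagerV1/V2 (all names here are hypothetical).
from typing import Dict, List


class ToyBlockManager:

    def __init__(self, block_size: int, num_gpu_blocks: int):
        self.block_size = block_size
        self.free_blocks = list(range(num_gpu_blocks))
        # seq_id -> decoder (self-attention) block numbers
        self.block_tables: Dict[int, List[int]] = {}
        # request_id -> encoder (cross-attention) block numbers; one table per
        # request because a SequenceGroup has at most one encoder sequence
        self.cross_block_tables: Dict[str, List[int]] = {}

    def _blocks_needed(self, num_tokens: int) -> int:
        return -(-num_tokens // self.block_size)  # ceil division

    def can_allocate(self, decoder_tokens: int, encoder_tokens: int) -> bool:
        needed = (self._blocks_needed(decoder_tokens) +
                  self._blocks_needed(encoder_tokens))
        return needed <= len(self.free_blocks)

    def allocate(self, request_id: str, seq_id: int, decoder_tokens: int,
                 encoder_tokens: int) -> None:
        n_dec = self._blocks_needed(decoder_tokens)
        n_enc = self._blocks_needed(encoder_tokens)
        self.block_tables[seq_id] = [
            self.free_blocks.pop() for _ in range(n_dec)
        ]
        self.cross_block_tables[request_id] = [
            self.free_blocks.pop() for _ in range(n_enc)
        ]

    def free_cross(self, request_id: str) -> None:
        # Mirrors the new free_cross(): release the encoder table by request id.
        self.free_blocks.extend(self.cross_block_tables.pop(request_id, []))


if __name__ == "__main__":
    mgr = ToyBlockManager(block_size=16, num_gpu_blocks=8)
    assert mgr.can_allocate(decoder_tokens=48, encoder_tokens=32)
    mgr.allocate("req-0", seq_id=0, decoder_tokens=48, encoder_tokens=32)
    print(mgr.block_tables, mgr.cross_block_tables)
    mgr.free_cross("req-0")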
@@ -1,11 +1,13 @@
 import pytest
 
+from vllm.core.block.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE,
+                                   STR_NOT_IMPL_ENC_DEC_SWA)
 from vllm.core.block_manager_v2 import BlockSpaceManagerV2
 from vllm.core.interfaces import AllocStatus
 from vllm.sequence import Logprob, SequenceStatus
 from vllm.utils import chunk_list
 
-from ..utils import create_seq_group
+from ..utils import create_seq_group, create_seq_group_encoder_decoder
 
 
 @pytest.mark.parametrize("block_size", [16])
@@ -52,6 +54,156 @@ def test_can_allocate_seq_group(block_size: int, num_seqs_per_group: int,
         assert can_allocate_result == AllocStatus.LATER
 
 
+@pytest.mark.parametrize("block_size", [16])
+@pytest.mark.parametrize("num_gpu_blocks", [16, 80, 160])
+@pytest.mark.parametrize("num_seqs_per_group", [1, 4])
+@pytest.mark.parametrize("watermark", [0.0, 0.5])
+def test_can_allocate_seq_group_encoder_decoder(block_size: int,
+                                                num_seqs_per_group: int,
+                                                num_gpu_blocks: int,
+                                                watermark: float):
+    block_manager = BlockSpaceManagerV2(
+        block_size=block_size,
+        num_gpu_blocks=num_gpu_blocks,
+        num_cpu_blocks=1024,
+        watermark=watermark,
+    )
+    num_watermark_blocks = int(watermark * num_gpu_blocks)
+
+    num_output_blocks_per_seq = 1
+
+    # NOTE: This should be num_output_blocks_per_seq * num_seqs_per_group, but
+    # the current implementation assumes all seqs are new prompts / don't have
+    # different output lens.
+    num_output_blocks = num_output_blocks_per_seq
+
+    for bdx, num_prompt_blocks in enumerate(
+            range(1, num_gpu_blocks - num_output_blocks)):
+        num_cross_blocks_per_seq = num_prompt_blocks
+
+        seq_group = create_seq_group_encoder_decoder(
+            seq_prompt_len=block_size * num_prompt_blocks,
+            seq_output_lens=[
+                block_size * num_output_blocks_per_seq
+                for _ in range(num_seqs_per_group)
+            ],
+            request_id=str(bdx))
+
+        assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks
+
+        can_allocate_result = block_manager.can_allocate(seq_group)
+
+        num_required_blocks = num_prompt_blocks + \
+                              num_output_blocks + \
+                              num_cross_blocks_per_seq
+
+        if num_gpu_blocks - num_required_blocks < num_watermark_blocks:
+            assert can_allocate_result == AllocStatus.NEVER
+        elif num_gpu_blocks >= num_required_blocks:
+            assert can_allocate_result == AllocStatus.OK
+        else:
+            assert can_allocate_result == AllocStatus.LATER
+
+
+@pytest.mark.parametrize("block_size", [16])
+@pytest.mark.parametrize("num_gpu_blocks", [16])
+@pytest.mark.parametrize("num_seqs_per_group", [1])
+@pytest.mark.parametrize("watermark", [0.0, 0.5])
+def test_can_allocate_encoder_decoder_fails_with_swa(block_size: int,
+                                                     num_seqs_per_group: int,
+                                                     num_gpu_blocks: int,
+                                                     watermark: float):
+    '''
+    SWA short for Sliding Window Attention.
+
+    At time of writing block manager v2 does not support SWA.
+
+    However even when SWA is implemented for block manager v2,
+    there will still most likely be a separate workstream required
+    to enable SWA for encoder/decoder models.
+
+    Therefore this test enforces that one of the following cases
+    hold true:
+    1. Block manager v2 does not support SWA at all (true at time of writing)
+    2. Block manager v2 fails with NotImplementError when SWA is enabled
+       AND a SequenceGroup with an encoder sequence (i.e. in support of an
+       encoder/decoder model) is passed into can_allocate() as an argument
+
+    The setup for this test is stripped down version of
+    test_can_allocate_seq_group_encoder_decoder()
+    '''
+
+    with pytest.raises((NotImplementedError, AssertionError)) as exc_info:
+        block_manager = BlockSpaceManagerV2(
+            block_size=block_size,
+            num_gpu_blocks=num_gpu_blocks,
+            num_cpu_blocks=1024,
+            watermark=watermark,
+            sliding_window=5  # SWA
+        )
+
+        num_output_blocks_per_seq = 1
+        num_prompt_blocks = 1
+        num_output_blocks = num_output_blocks_per_seq
+        seq_group = create_seq_group_encoder_decoder(
+            seq_prompt_len=block_size * num_prompt_blocks,
+            seq_output_lens=[
+                block_size * num_output_blocks_per_seq
+                for _ in range(num_seqs_per_group)
+            ],
+            request_id="0")
+
+        assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks
+        block_manager.can_allocate(seq_group)
+
+    # Assert that either
+    # 1. Block manager v2 constructor fails with assertion that sliding window
+    #    is not yet supported (most likely near-term outcome at time of
+    #    writing), or
+    # 2. can_allocate() fails with NotImplementedError due to combination of
+    #    encoder/decoder and sliding window attention
+    if isinstance(exc_info.value, NotImplementedError):
+        assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_SWA
+    elif isinstance(exc_info.value, AssertionError):
+        assert str(exc_info.value) == "Sliding window not yet supported"
+
+
+@pytest.mark.parametrize("block_size", [16])
+@pytest.mark.parametrize("num_gpu_blocks", [16])
+@pytest.mark.parametrize("num_seqs_per_group", [1])
+@pytest.mark.parametrize("watermark", [0.0, 0.5])
+def test_can_allocate_encoder_decoder_fails_with_prefix_cache(
+        block_size: int, num_seqs_per_group: int, num_gpu_blocks: int,
+        watermark: float):
+
+    block_manager = BlockSpaceManagerV2(
+        block_size=block_size,
+        num_gpu_blocks=num_gpu_blocks,
+        num_cpu_blocks=1024,
+        watermark=watermark,
+        enable_caching=True  # Prefix cache
+    )
+
+    num_output_blocks_per_seq = 1
+    num_prompt_blocks = 1
+    num_output_blocks = num_output_blocks_per_seq
+    seq_group = create_seq_group_encoder_decoder(
+        seq_prompt_len=block_size * num_prompt_blocks,
+        seq_output_lens=[
+            block_size * num_output_blocks_per_seq
+            for _ in range(num_seqs_per_group)
+        ],
+        request_id="0")
+
+    assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks
+
+    # Assert that either can_allocate() fails with NotImplementedError
+    # due to combination of encoder/decoder and prefix cache
+    with pytest.raises(NotImplementedError) as exc_info:
+        block_manager.can_allocate(seq_group)
+    assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE
+
+
 @pytest.mark.parametrize("block_size", [1, 8])
 @pytest.mark.parametrize("prompt_len", [1, 7, 8])
 @pytest.mark.parametrize("num_slots_to_append", [1, 8, 129])
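The decision rule asserted in test_can_allocate_seq_group_encoder_decoder above can be made concrete with one round of made-up numbers (this is only a worked example of the test's arithmetic, not additional vLLM behavior):

# Worked example (made-up numbers) of the rule the test asserts: NEVER if
# even the watermark reserve cannot be kept, OK if the request fits now,
# LATER otherwise.
num_gpu_blocks = 16
watermark = 0.5
num_watermark_blocks = int(watermark * num_gpu_blocks)  # 8

num_prompt_blocks = 5   # decoder prompt blocks
num_output_blocks = 1   # first decode block
num_cross_blocks = 5    # encoder prompt blocks (cross-attention KV)
num_required_blocks = num_prompt_blocks + num_output_blocks + num_cross_blocks  # 11

if num_gpu_blocks - num_required_blocks < num_watermark_blocks:
    print("AllocStatus.NEVER")   # 16 - 11 = 5 < 8 for this configuration
elif num_gpu_blocks >= num_required_blocks:
    print("AllocStatus.OK")
else:
    print("AllocStatus.LATER")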
@@ -6,13 +6,15 @@ import pytest
 
 from vllm import SamplingParams
 from vllm.block import PhysicalTokenBlock
+from vllm.core.block.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE,
+                                   STR_NOT_IMPL_ENC_DEC_SWA)
 from vllm.core.block_manager_v1 import (BlockSpaceManagerV1,
                                         UncachedBlockAllocator)
 from vllm.core.interfaces import AllocStatus
 from vllm.sequence import Logprob, Sequence, SequenceGroup, SequenceStatus
 from vllm.utils import Device
 
-from .utils import create_dummy_prompt
+from .utils import create_dummy_prompt, create_dummy_prompt_encoder_decoder
 
 
 def test_block_allocator_allocate():
@@ -73,7 +75,7 @@ def test_allocate():
     # Allocate same sequence group to all available gpu blocks.
     for i in range(num_gpu_blocks):
         _, seq_group = create_dummy_prompt(str(i), block_size)
-        assert block_manager.can_allocate(seq_group)
+        assert block_manager.can_allocate(seq_group) == AllocStatus.OK
         block_manager.allocate(seq_group)
     assert block_manager.can_allocate(seq_group) != AllocStatus.OK
 
@@ -85,11 +87,107 @@ def test_allocate():
                                         watermark=1 / num_gpu_blocks)
     for i in range(num_gpu_blocks - 1):
         _, seq_group = create_dummy_prompt(str(i), block_size)
-        assert block_manager.can_allocate(seq_group)
+        assert block_manager.can_allocate(seq_group) == AllocStatus.OK
         block_manager.allocate(seq_group)
     assert block_manager.can_allocate(seq_group) != AllocStatus.OK
 
 
+def test_allocate_encoder_decoder():
+    block_size = 4
+    num_cpu_blocks = 4
+    num_gpu_blocks = 4
+    block_req_per_seq_group = 2
+    block_manager = BlockSpaceManagerV1(block_size,
+                                        num_cpu_blocks,
+                                        num_gpu_blocks,
+                                        watermark=0)
+
+    # Allocate same sequence group to all available gpu blocks.
+    for i in range(num_gpu_blocks // block_req_per_seq_group):
+        _, _, seq_group = create_dummy_prompt_encoder_decoder(
+            str(i),
+            decoder_prompt_length=block_size,
+            encoder_prompt_length=block_size)
+        assert block_manager.can_allocate(seq_group) == AllocStatus.OK
+        block_manager.allocate(seq_group)
+    assert block_manager.can_allocate(seq_group) != AllocStatus.OK
+
+    # Allocate same sequence group to all available gpu blocks.
+    # Use watermark to reserve one gpu block.
+    block_manager = BlockSpaceManagerV1(block_size,
+                                        num_cpu_blocks,
+                                        num_gpu_blocks,
+                                        watermark=1 / num_gpu_blocks)
+    for i in range((num_gpu_blocks - 1) // block_req_per_seq_group):
+        _, _, seq_group = create_dummy_prompt_encoder_decoder(
+            str(i),
+            decoder_prompt_length=block_size,
+            encoder_prompt_length=block_size)
+        assert block_manager.can_allocate(seq_group) == AllocStatus.OK
+        block_manager.allocate(seq_group)
+    assert block_manager.can_allocate(seq_group) != AllocStatus.OK
+
+
+def test_allocate_encoder_decoder_fails_with_swa():
+    # SWA short for sliding window attention
+
+    block_size = 4
+    num_cpu_blocks = 4
+    num_gpu_blocks = 4
+    block_manager = BlockSpaceManagerV1(block_size,
+                                        num_cpu_blocks,
+                                        num_gpu_blocks,
+                                        watermark=0,
+                                        sliding_window=5)  # swa
+
+    # Allocate same sequence group to all available gpu blocks.
+    _, _, seq_group = create_dummy_prompt_encoder_decoder(
+        "0",
+        decoder_prompt_length=block_size,
+        encoder_prompt_length=block_size)
+
+    # Assert that can_allocate() fails due to SWA
+    with pytest.raises(NotImplementedError) as exc_info:
+        block_manager.can_allocate(seq_group)
+
+    assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_SWA
+
+    # Assert that allocate() fails due to SWA
+    with pytest.raises(NotImplementedError) as exc_info:
+        block_manager.allocate(seq_group)
+
+    assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_SWA
+
+
+def test_allocate_encoder_decoder_fails_with_prefix_caching():
+    block_size = 4
+    num_cpu_blocks = 4
+    num_gpu_blocks = 4
+    block_manager = BlockSpaceManagerV1(block_size,
+                                        num_cpu_blocks,
+                                        num_gpu_blocks,
+                                        watermark=0,
+                                        enable_caching=True)  # Prefix cache
+
+    # Allocate same sequence group to all available gpu blocks.
+    _, _, seq_group = create_dummy_prompt_encoder_decoder(
+        "0",
+        decoder_prompt_length=block_size,
+        encoder_prompt_length=block_size)
+
+    # Assert that can_allocate() fails due to prefix caching
+    with pytest.raises(NotImplementedError) as exc_info:
+        block_manager.can_allocate(seq_group)
+
+    assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE
+
+    # Assert that allocate() fails due to prefix caching
+    with pytest.raises(NotImplementedError) as exc_info:
+        block_manager.allocate(seq_group)
+
+    assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE
+
+
 def test_append_slot_single_seq():
     block_size = 4
     num_cpu_blocks = 4
@@ -244,6 +342,62 @@ def test_swap():
     assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks)
 
 
+def test_swap_encoder_decoder():
+    block_size = 4
+    num_cpu_blocks = 4
+    num_gpu_blocks = 4
+    block_manager = BlockSpaceManagerV1(block_size,
+                                        num_cpu_blocks,
+                                        num_gpu_blocks,
+                                        watermark=0)
+
+    decoder_prompt, encoder_prompt, seq_group = \
+        create_dummy_prompt_encoder_decoder(
+            "1",
+            decoder_prompt_length=block_size,
+            encoder_prompt_length=block_size)
+    decoder_prompt.status = SequenceStatus.WAITING
+    encoder_prompt.status = SequenceStatus.WAITING
+    block_manager.allocate(seq_group)
+
+    # Emulate a forward pass by appending a single token.
+    # The block manager then knows how many unprocessed
+    # tokens will be written in the next forward pass.
+    token_id = 0
+    decoder_prompt.status = SequenceStatus.RUNNING
+    decoder_prompt.append_token_id(token_id, {token_id: Logprob(0.0)})
+
+    # Swap encoder/decoder seq group from GPU -> CPU.
+    decoder_gpu_blocks = block_manager.get_block_table(decoder_prompt)
+    cross_gpu_blocks = block_manager.get_cross_block_table(seq_group)
+    gpu_blocks = decoder_gpu_blocks + cross_gpu_blocks
+    assert block_manager.can_swap_out(seq_group)
+    before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
+    before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
+    mapping = block_manager.swap_out(seq_group)
+    assert [x[0] for x in mapping] == gpu_blocks
+    #assert list(mapping.keys()) == gpu_blocks
+    after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
+    after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
+    assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks)
+    assert before_gpu_blocks + len(gpu_blocks) == after_gpu_blocks
+    decoder_prompt.status = SequenceStatus.SWAPPED
+
+    # Swap encoder/decoder seq group from CPU -> GPU.
+    decoder_cpu_blocks = block_manager.get_block_table(decoder_prompt)
+    cross_cpu_blocks = block_manager.get_cross_block_table(seq_group)
+    cpu_blocks = decoder_cpu_blocks + cross_cpu_blocks
+    assert block_manager.can_swap_in(seq_group) == AllocStatus.OK
+    before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
+    before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
+    mapping = block_manager.swap_in(seq_group)
+    assert [x[0] for x in mapping] == cpu_blocks
+    after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
+    after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
+    assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks
+    assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks)
+
+
 def test_free():
     block_size = 4
     num_cpu_blocks = 4
@@ -268,6 +422,41 @@ def test_free():
         block_manager.get_block_table(prompt)
 
 
+def test_free_encoder_decoder():
+    block_size = 4
+    num_cpu_blocks = 4
+    num_gpu_blocks = 4
+    block_manager = BlockSpaceManagerV1(block_size,
+                                        num_cpu_blocks,
+                                        num_gpu_blocks,
+                                        watermark=0)
+
+    decoder_prompt, encoder_prompt, seq_group = \
+        create_dummy_prompt_encoder_decoder(
+            "1",
+            decoder_prompt_length=block_size,
+            encoder_prompt_length=block_size)
+    block_manager.allocate(seq_group)
+
+    # Free allocated seq.
+    decoder_prompt_blocks = len(block_manager.get_block_table(decoder_prompt))
+    encoder_prompt_blocks = len(block_manager.get_cross_block_table(seq_group))
+    prompt_blocks = decoder_prompt_blocks + encoder_prompt_blocks
+    before_blocks = block_manager.get_num_free_gpu_blocks()
+    block_manager.free(decoder_prompt)
+    block_manager.free_cross(seq_group)
+    after_blocks = block_manager.get_num_free_gpu_blocks()
+    assert after_blocks == before_blocks + prompt_blocks
+
+    # Block table for freed encoder & decoder seq's are deleted.
+    with pytest.raises(KeyError):
+        block_manager.get_block_table(decoder_prompt)
+
+    # Block table for freed encoder & decoder seq's are deleted.
+    with pytest.raises(KeyError):
+        block_manager.get_block_table(encoder_prompt)
+
+
 def test_reset():
     block_size = 4
     num_cpu_blocks = 4
@@ -289,6 +478,31 @@ def test_reset():
     assert block_manager.get_num_free_gpu_blocks() == original_blocks
 
 
+def test_reset_encoder_decoder():
+    block_size = 4
+    num_cpu_blocks = 4
+    num_gpu_blocks = 4
+    block_req_per_seq_group = 2
+    block_manager = BlockSpaceManagerV1(block_size,
+                                        num_cpu_blocks,
+                                        num_gpu_blocks,
+                                        watermark=0)
+
+    # Allocate same seq group on all available gpu blocks.
+    original_blocks = block_manager.get_num_free_gpu_blocks()
+    for i in range(num_gpu_blocks // block_req_per_seq_group):
+        _, _, seq_group = create_dummy_prompt_encoder_decoder(
+            f"{i}",
+            decoder_prompt_length=block_size,
+            encoder_prompt_length=block_size)
+        block_manager.allocate(seq_group)
+    assert block_manager.get_num_free_gpu_blocks() == 0
+
+    # Resetting block manager frees all allocated blocks.
+    block_manager.reset()
+    assert block_manager.get_num_free_gpu_blocks() == original_blocks
+
+
 def test_sliding_window_multi_seq():
     """
     Tests that memory allocation and deallocation is handled
@@ -39,6 +39,52 @@ def create_dummy_prompt(
     return prompt, seq_group
 
 
+def create_dummy_prompt_encoder_decoder(
+    request_id: str,
+    decoder_prompt_length: int,
+    encoder_prompt_length: int,
+    block_size: Optional[int] = None,
+    lora_request: Optional[LoRARequest] = None,
+    use_beam_search: bool = False,
+    best_of: int = 1,
+) -> Tuple[Sequence, SequenceGroup]:
+    if not block_size:
+        block_size = decoder_prompt_length
+
+    # Create dummy prompt sequence with tokens 0...block_size-1
+    # and prompt "0 ... block_size".
+    decoder_prompt_tokens = list(range(decoder_prompt_length))
+    decoder_prompt_str = " ".join([str(t) for t in decoder_prompt_tokens])
+
+    decoder_prompt = Sequence(int(request_id),
+                              inputs={
+                                  "prompt": decoder_prompt_str,
+                                  "prompt_token_ids": decoder_prompt_tokens,
+                                  "multi_modal_data": None,
+                              },
+                              block_size=block_size)
+
+    encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length))))
+    encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens])
+    encoder_prompt = Sequence(int(request_id),
+                              inputs={
+                                  "prompt": encoder_prompt_str,
+                                  "prompt_token_ids": encoder_prompt_tokens,
+                                  "multi_modal_data": None,
+                              },
+                              block_size=block_size)
+    seq_group = SequenceGroup(request_id=request_id,
+                              seqs=[decoder_prompt],
+                              sampling_params=SamplingParams(
+                                  use_beam_search=use_beam_search,
+                                  best_of=best_of),
+                              arrival_time=time.time(),
+                              lora_request=lora_request,
+                              encoder_seq=encoder_prompt)
+
+    return decoder_prompt, encoder_prompt, seq_group
+
+
 def create_seq_group(
     seq_prompt_len: int = 1024,
     seq_output_lens: Iterable[int] = (128, ),
@@ -82,5 +128,56 @@ def create_seq_group(
     return seq_group
 
 
+def create_seq_group_encoder_decoder(
+        seq_prompt_len: int = 1024,
+        seq_output_lens: Iterable[int] = (128, ),
+        request_id: str = '0',
+        seq_id_start: int = 0,
+        sampling_params: Optional[SamplingParams] = None) -> SequenceGroup:
+
+    assert len(seq_output_lens) > 0
+
+    if sampling_params is None:
+        sampling_params = SamplingParams()
+
+    prompt_token_ids = [0] * seq_prompt_len
+
+    seqs = []
+    for seq_id_offset, output_len in enumerate(seq_output_lens):
+        seq = Sequence(
+            seq_id=seq_id_start + seq_id_offset,
+            inputs={
+                "prompt": "",
+                "prompt_token_ids": prompt_token_ids,
+                "multi_modal_data": None,
+            },
+            block_size=16,
+        )
+
+        for i in range(output_len):
+            seq.append_token_id(
+                token_id=i,
+                logprobs={i: Logprob(0.0)},
+            )
+        seqs.append(seq)
+
+    # Encoder sequence
+    encoder_seq = Sequence(
+        seq_id=seq_id_start + len(seq_output_lens),
+        inputs={
+            "prompt": "",
+            "prompt_token_ids": prompt_token_ids,
+            "multi_modal_data": None,
+        },
+        block_size=16,
+    )
+
+    return SequenceGroup(request_id=request_id,
+                         seqs=seqs,
+                         sampling_params=sampling_params,
+                         arrival_time=time.time(),
+                         encoder_seq=encoder_seq)
+
+
 def round_up_to_next_block(seq_len: int, block_size: int) -> int:
     return (seq_len + block_size - 1) // block_size
vllm/core/block/utils.py (new file, 56 lines)
@@ -0,0 +1,56 @@
+"""Block manager utils."""
+from vllm.sequence import SequenceGroup
+
+# Exception strings for non-implemented block manager enc/dec scenarios
+
+STR_NOT_IMPL_ENC_DEC_SWA = \
+    "Sliding window attention for encoder/decoder models " + \
+    "is not currently supported."
+
+STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \
+    "Prefix caching for encoder/decoder models " + \
+    "is not currently supported."
+
+
+def _get_block_mgr_sliding_window_attr(block_mgr):
+    '''
+    BlockManagerV1 and BlockManagerV2 have slightly different
+    members related to sliding window attention (SWA). This
+    function extracts the appropriate member to use for determining
+    whether SWA is enabled.
+
+    Arguments:
+
+    * block_mgr: BlockManagerV1 or BlockManagerV2 instance
+    '''
+
+    if hasattr(block_mgr, 'block_sliding_window'):
+        return block_mgr.block_sliding_window
+    if hasattr(block_mgr, 'max_block_sliding_window'):
+        return block_mgr.max_block_sliding_window
+
+    raise AttributeError("Block manager instance has neither " + \
+                         "block_sliding_window nor " + \
+                         "max_block_sliding_window attributes.")
+
+
+def check_no_caching_or_swa_for_blockmgr_encdec(
+        block_mgr, seq_group: SequenceGroup) -> None:
+    '''
+    Enforce that prefix caching & sliding-window attention (SWA)
+    are currently unsupported *specifically* for encoder/decoder models.
+
+    Raises NotImplementedError if unsupported scenario is detected.
+
+    Arguments:
+
+    * block_mgr: BlockSpaceManager instance
+    * seq_group: SequenceGroup passed to block_mgr
+    '''
+
+    if seq_group.is_encoder_decoder():
+        if _get_block_mgr_sliding_window_attr(block_mgr) is not None:
+            raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA)
+
+        if block_mgr.enable_caching:
+            raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE)
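The helper above is what both block managers call at the top of can_allocate() and allocate(). A small sketch of the calling pattern follows; the SimpleNamespace stand-ins replace a real block manager and SequenceGroup, which are the only assumptions beyond what the new module defines.

# Sketch of how the guard above is intended to be used (DummyManager-style
# stand-ins; the real callers are BlockSpaceManagerV1/V2).
from types import SimpleNamespace

from vllm.core.block.utils import (STR_NOT_IMPL_ENC_DEC_SWA,
                                   check_no_caching_or_swa_for_blockmgr_encdec)

# Looks enough like a block manager: exposes block_sliding_window and
# enable_caching, which is all the guard inspects.
dummy_manager = SimpleNamespace(block_sliding_window=5, enable_caching=False)

# Looks enough like a SequenceGroup: only is_encoder_decoder() is consulted.
dummy_seq_group = SimpleNamespace(is_encoder_decoder=lambda: True)

try:
    check_no_caching_or_swa_for_blockmgr_encdec(dummy_manager, dummy_seq_group)
except NotImplementedError as e:
    assert str(e) == STR_NOT_IMPL_ENC_DEC_SWA  # SWA + enc/dec is rejected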
@@ -8,6 +8,7 @@ from typing import Sequence as GenericSequence
 from typing import Set, Tuple
 
 from vllm.block import BlockTable, PhysicalTokenBlock
+from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec
 from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor
 from vllm.core.interfaces import AllocStatus, BlockSpaceManager
 from vllm.logger import init_logger
@@ -255,14 +256,30 @@ class BlockSpaceManagerV1(BlockSpaceManager):
             Device.CPU, block_size, num_cpu_blocks)
         # Mapping: seq_id -> BlockTable.
         self.block_tables: Dict[int, BlockTable] = {}
+        # Mapping: req_id -> BlockTable
+        # Note that each SequenceGroup has a unique
+        # request ID
+        self.cross_block_tables: Dict[str, BlockTable] = {}
+
+    def _get_seq_num_required_blocks(self, seq: Sequence) -> int:
+        return 0 if seq is None \
+            else len(seq.logical_token_blocks)
 
     def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
         # FIXME(woosuk): Here we assume that all sequences in the group share
         # the same prompt. This may not be true for preempted sequences.
-        seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
-        num_required_blocks = len(seq.logical_token_blocks)
+
+        check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)
+
+        self_num_required_blocks = self._get_seq_num_required_blocks(
+            seq_group.get_seqs(status=SequenceStatus.WAITING)[0])
+        cross_num_required_blocks = self._get_seq_num_required_blocks(
+            seq_group.get_encoder_seq())
+        num_required_blocks = self_num_required_blocks + \
+                              cross_num_required_blocks
 
         if self.block_sliding_window is not None:
+
             num_required_blocks = min(num_required_blocks,
                                       self.block_sliding_window)
         num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
@@ -276,11 +293,10 @@ class BlockSpaceManagerV1(BlockSpaceManager):
         else:
             return AllocStatus.LATER
 
-    def allocate(self, seq_group: SequenceGroup) -> None:
-        # NOTE: Here we assume that all sequences in the group have the same
-        # prompt.
-        seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
-
+    def _allocate_sequence(self, \
+                           seq: Sequence, \
+                           ref_count: int, \
+                           is_encoder_decoder: bool = True) -> BlockTable:
         # Allocate new physical token blocks that will store the prompt tokens.
         num_prompt_blocks = len(seq.logical_token_blocks)
 
@@ -290,21 +306,46 @@ class BlockSpaceManagerV1(BlockSpaceManager):
                     and logical_idx >= self.block_sliding_window):
                 block = block_table[logical_idx % self.block_sliding_window]
                 # Set the reference counts of the token blocks.
-                block.ref_count = seq_group.num_seqs()
-            elif self.enable_caching:
+                block.ref_count = ref_count
+            elif not is_encoder_decoder and self.enable_caching:
                 block = self.gpu_allocator.allocate(
                     seq.hash_of_block(logical_idx),
                     seq.num_hashed_tokens_of_block(logical_idx))
             else:
                 block = self.gpu_allocator.allocate()
                 # Set the reference counts of the token blocks.
-                block.ref_count = seq_group.num_seqs()
+                block.ref_count = ref_count
             block_table.append(block)
 
-        # Assign the block table for each sequence.
+        return block_table
+
+    def allocate(self, seq_group: SequenceGroup) -> None:
+        is_encoder_decoder = seq_group.is_encoder_decoder()
+        check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)
+
+        # Allocate decoder sequences
+        #
+        # NOTE: Here we assume that all sequences in the group have the same
+        # decoder prompt.
+        seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
+        block_table: BlockTable = \
+            self._allocate_sequence(seq,
+                                    seq_group.num_seqs(),
+                                    is_encoder_decoder)
+
+        # Assign the self-attention block tables for each sequence.
         for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
             self.block_tables[seq.seq_id] = block_table.copy()
 
+        # Allocate encoder sequence
+        if is_encoder_decoder:
+            # A SequenceGroup has only a single encoder sequence (at most),
+            # thus allocate with a ref count of 1
+            block_table = self._allocate_sequence(seq_group.get_encoder_seq(),
+                                                  1, is_encoder_decoder)
+            # Assign the cross-attention block table for the SequenceGroup.
+            self.cross_block_tables[seq_group.request_id] = block_table
+
     def can_append_slots(self,
                          seq_group: SequenceGroup,
                          num_lookahead_slots: int = 0) -> bool:
@@ -443,13 +484,18 @@ class BlockSpaceManagerV1(BlockSpaceManager):
 
     def _get_physical_blocks(
             self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]:
+
         # NOTE: Here, we assume that the physical blocks are only shared by
         # the sequences in the same group.
+        request_id = seq_group.request_id
         blocks: Set[PhysicalTokenBlock] = set()
         for seq in seq_group.get_seqs():
             if seq.is_finished():
                 continue
             blocks.update(self.block_tables[seq.seq_id])
+        # Cross-attention blocks
+        if seq_group.is_encoder_decoder():
+            blocks.update(self.cross_block_tables[request_id])
         return list(blocks)
 
     def can_swap_in(self,
@@ -457,8 +503,11 @@ class BlockSpaceManagerV1(BlockSpaceManager):
                     num_lookahead_slots: int = 0) -> AllocStatus:
         assert (num_lookahead_slots == 0
                 ), "BlockSpaceManagerV1 does not support lookahead allocation"
+
         blocks = self._get_physical_blocks(seq_group)
         num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
+        if seq_group.is_encoder_decoder():
+            num_swapped_seqs += 1
         num_free_blocks = self.gpu_allocator.get_num_free_blocks()
         # NOTE: Conservatively, we assume that every sequence will allocate
         # at least one free block right after the swap-in.
@@ -471,70 +520,81 @@ class BlockSpaceManagerV1(BlockSpaceManager):
         else:
             return AllocStatus.LATER
 
+    def _swap_block_table(
+            self, block_table: BlockTable, src_allocator: BlockAllocatorBase,
+            dest_allocator: BlockAllocatorBase,
+            mapping: Dict[PhysicalTokenBlock,
+                          PhysicalTokenBlock]) -> BlockTable:
+        new_block_table = []
+
+        for from_block in block_table:
+            if from_block in mapping:
+                to_block = mapping[from_block]
+                to_block.ref_count += 1
+            else:
+                to_block = dest_allocator.allocate(
+                    from_block.block_hash, from_block.num_hashed_tokens)
+                mapping[from_block] = to_block
+            new_block_table.append(to_block)
+            # Free the source block swapped in to destination.
+            src_allocator.free(from_block)
+
+        return new_block_table
+
     def swap_in(self,
                 seq_group: SequenceGroup,
                 num_lookahead_slots: int = 0) -> List[Tuple[int, int]]:
         assert (num_lookahead_slots == 0
                 ), "BlockSpaceManagerV1 does not support lookahead allocation"
 
+        request_id = seq_group.request_id
+
         # CPU block -> GPU block.
         # dict is efficient in lookup `if cpu_block in mapping`
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
-            new_block_table: BlockTable = []
-            block_table = self.block_tables[seq.seq_id]
-
-            for cpu_block in block_table:
-                if cpu_block in mapping:
-                    gpu_block = mapping[cpu_block]
-                    gpu_block.ref_count += 1
-                else:
-                    gpu_block = self.gpu_allocator.allocate(
-                        cpu_block.block_hash, cpu_block.num_hashed_tokens)
-                    mapping[cpu_block] = gpu_block
-                new_block_table.append(gpu_block)
-                # Free the CPU block swapped in to GPU.
-                self.cpu_allocator.free(cpu_block)
-            self.block_tables[seq.seq_id] = new_block_table
+            self.block_tables[seq.seq_id] = \
+                self._swap_block_table(self.block_tables[seq.seq_id],
+                                       self.cpu_allocator,
+                                       self.gpu_allocator,
+                                       mapping)
 
-        block_number_mapping = {
-            cpu_block.block_number: gpu_block.block_number
-            for cpu_block, gpu_block in mapping.items()
-        }
-        # convert to list of tuples once here
-        return list(block_number_mapping.items())
+        if seq_group.is_encoder_decoder():
+            self.cross_block_tables[request_id] = \
+                self._swap_block_table(self.cross_block_tables[request_id],
+                                       self.cpu_allocator,
+                                       self.gpu_allocator,
+                                       mapping)
+
+        return [(cpu_block.block_number, gpu_block.block_number)
+                for cpu_block, gpu_block in mapping.items()]
 
     def can_swap_out(self, seq_group: SequenceGroup) -> bool:
         blocks = self._get_physical_blocks(seq_group)
         return len(blocks) <= self.cpu_allocator.get_num_free_blocks()
 
     def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
+        request_id = seq_group.request_id
+
         # GPU block -> CPU block.
         # dict is efficient in lookup `if gpu_block in mapping`
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-            new_block_table: BlockTable = []
-            block_table = self.block_tables[seq.seq_id]
-
-            for gpu_block in block_table:
-                if gpu_block in mapping:
-                    cpu_block = mapping[gpu_block]
-                    cpu_block.ref_count += 1
-                else:
-                    cpu_block = self.cpu_allocator.allocate(
-                        gpu_block.block_hash, gpu_block.num_hashed_tokens)
-                    mapping[gpu_block] = cpu_block
-                new_block_table.append(cpu_block)
-                # Free the GPU block swapped out to CPU.
-                self.gpu_allocator.free(gpu_block)
-            self.block_tables[seq.seq_id] = new_block_table
+            self.block_tables[seq.seq_id] = \
+                self._swap_block_table(self.block_tables[seq.seq_id],
+                                       self.gpu_allocator,
+                                       self.cpu_allocator,
+                                       mapping)
 
-        block_number_mapping = {
-            gpu_block.block_number: cpu_block.block_number
-            for gpu_block, cpu_block in mapping.items()
-        }
-        # convert to list of tuples once here
-        return list(block_number_mapping.items())
+        if seq_group.is_encoder_decoder():
+            self.cross_block_tables[request_id] = \
+                self._swap_block_table(self.cross_block_tables[request_id],
+                                       self.gpu_allocator,
+                                       self.cpu_allocator,
+                                       mapping)
+
+        return [(cpu_block.block_number, gpu_block.block_number)
+                for cpu_block, gpu_block in mapping.items()]
 
     def _free_block_table(self, block_table: BlockTable) -> None:
         # when using a sliding window, each seq will only use up
@@ -559,15 +619,32 @@ class BlockSpaceManagerV1(BlockSpaceManager):
             self._free_block_table(block_table)
         del self.block_tables[seq.seq_id]
 
+    def free_cross(self, seq_group: SequenceGroup) -> None:
+        if seq_group.request_id not in self.cross_block_tables:
+            # Already freed or hasn't been scheduled yet.
+            return
+        block_table = self.cross_block_tables[seq_group.request_id]
+        self._free_block_table(block_table)
+        del self.cross_block_tables[seq_group.request_id]
+
     def reset(self) -> None:
+        # Free decoder block tables
        for block_table in self.block_tables.values():
             self._free_block_table(block_table)
         self.block_tables.clear()
+        # Free cross-attention block tables
+        for block_table in self.cross_block_tables.values():
+            self._free_block_table(block_table)
+        self.cross_block_tables.clear()
 
     def get_block_table(self, seq: Sequence) -> List[int]:
         block_table = self.block_tables[seq.seq_id]
         return [block.block_number for block in block_table]
 
+    def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]:
+        block_table = self.cross_block_tables[seq_group.request_id]
+        return [block.block_number for block in block_table]
+
     def get_num_free_gpu_blocks(self) -> int:
         return self.gpu_allocator.get_num_free_blocks()
 
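The _swap_block_table() refactor above threads one shared `mapping` dict through every table of a sequence group, so a physical block referenced by more than one table (in the real code, decoder tables of forked sequences in a group; the same dict is also passed for the new cross-attention table) is moved once and merely re-referenced afterwards. The following is a simplified model of that idea only; ToyBlock, ToyAllocator, and swap_table are toy stand-ins, not the vLLM classes.

# Simplified model of the shared-mapping swap (hypothetical toy classes).
from typing import Dict, List


class ToyBlock:

    def __init__(self, number: int):
        self.number = number
        self.ref_count = 0


class ToyAllocator:

    def __init__(self, start: int):
        self._next = start

    def allocate(self) -> ToyBlock:
        block = ToyBlock(self._next)
        block.ref_count = 1
        self._next += 1
        return block

    def free(self, block: ToyBlock) -> None:
        block.ref_count -= 1


def swap_table(table: List[ToyBlock], src: ToyAllocator, dst: ToyAllocator,
               mapping: Dict[ToyBlock, ToyBlock]) -> List[ToyBlock]:
    new_table = []
    for from_block in table:
        if from_block in mapping:        # already moved for another table
            to_block = mapping[from_block]
            to_block.ref_count += 1
        else:
            to_block = dst.allocate()
            mapping[from_block] = to_block
        new_table.append(to_block)
        src.free(from_block)
    return new_table


gpu, cpu = ToyAllocator(0), ToyAllocator(100)
shared = gpu.allocate()                  # block referenced by both tables
table_a = [shared, gpu.allocate()]
table_b = [shared]

mapping: Dict[ToyBlock, ToyBlock] = {}
table_a = swap_table(table_a, gpu, cpu, mapping)
table_b = swap_table(table_b, gpu, cpu, mapping)
assert table_a[0] is table_b[0]          # moved once, referenced twice
print([(a.number, b.number) for a, b in mapping.items()])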
@@ -5,11 +5,13 @@ from typing import Tuple
 
 from vllm.core.block.block_table import BlockTable
 from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
+from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec
 from vllm.core.interfaces import AllocStatus, BlockSpaceManager
 from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
 from vllm.utils import Device
 
 SeqId = int
+EncoderSeqId = str
 
 
 class BlockSpaceManagerV2(BlockSpaceManager):
@@ -94,17 +96,26 @@ class BlockSpaceManagerV2(BlockSpaceManager):
         )
 
         self.block_tables: Dict[SeqId, BlockTable] = {}
+        self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {}
 
     def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
         # FIXME(woosuk): Here we assume that all sequences in the group share
         # the same prompt. This may not be true for preempted sequences.
-        seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
+
+        check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)
+
+        seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
         num_required_blocks = BlockTable.get_num_required_blocks(
             seq.get_token_ids(),
             block_size=self.block_size,
         )
 
+        if seq_group.is_encoder_decoder():
+            num_required_blocks += BlockTable.get_num_required_blocks(
+                seq_group.get_encoder_seq().get_token_ids(),
+                block_size=self.block_size,
+            )
+
         if self.max_block_sliding_window is not None:
             num_required_blocks = min(num_required_blocks,
                                       self.max_block_sliding_window)
@@ -121,7 +132,19 @@ class BlockSpaceManagerV2(BlockSpaceManager):
         else:
             return AllocStatus.LATER
 
+    def _allocate_sequence(self, seq: Sequence) -> BlockTable:
+        block_table = BlockTable(
+            block_size=self.block_size,
+            block_allocator=self.block_allocator,
+            max_block_sliding_window=self.max_block_sliding_window,
+        )
+        block_table.allocate(seq.get_token_ids())
+
+        return block_table
+
     def allocate(self, seq_group: SequenceGroup) -> None:
+
+        # Allocate self-attention block tables for decoder sequences
         waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
         assert not (set(seq.seq_id for seq in waiting_seqs)
                     & self.block_tables.keys()), "block table already exists"
@@ -129,20 +152,29 @@ class BlockSpaceManagerV2(BlockSpaceManager):
         # NOTE: Here we assume that all sequences in the group have the same
         # prompt.
         seq = waiting_seqs[0]
-
-        block_table = BlockTable(
-            block_size=self.block_size,
-            block_allocator=self.block_allocator,
-            max_block_sliding_window=self.max_block_sliding_window,
-        )
-
-        block_table.allocate(seq.get_token_ids())
+        block_table: BlockTable = self._allocate_sequence(seq)
         self.block_tables[seq.seq_id] = block_table
 
         # Assign the block table for each sequence.
         for seq in waiting_seqs[1:]:
             self.block_tables[seq.seq_id] = block_table.fork()
 
+        # Allocate cross-attention block table for encoder sequence
+        #
+        # NOTE: Here we assume that all sequences in the group have the same
+        # encoder prompt.
+        request_id = seq_group.request_id
+
+        assert (request_id
                not in self.cross_block_tables), \
+            "block table already exists"
+
+        check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)
+
+        if seq_group.is_encoder_decoder():
+            block_table = self._allocate_sequence(seq_group.get_encoder_seq())
+            self.cross_block_tables[request_id] = block_table
+
     def can_append_slots(self, seq_group: SequenceGroup,
                          num_lookahead_slots: int) -> bool:
         """Determine if there is enough space in the GPU KV cache to continue
@@ -197,12 +229,27 @@ class BlockSpaceManagerV2(BlockSpaceManager):
         self.block_tables[seq.seq_id].free()
         del self.block_tables[seq.seq_id]
 
+    def free_cross(self, seq_group: SequenceGroup) -> None:
+        request_id = seq_group.request_id
+        if request_id not in self.cross_block_tables:
+            # Already freed or hasn't been scheduled yet.
+            return
+        self.cross_block_tables[request_id].free()
+        del self.cross_block_tables[request_id]
+
     def get_block_table(self, seq: Sequence) -> List[int]:
         assert seq.seq_id in self.block_tables
         block_ids = self.block_tables[seq.seq_id].physical_block_ids
         assert all(b is not None for b in block_ids)
         return block_ids  # type: ignore
 
+    def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]:
+        request_id = seq_group.request_id
+        assert request_id in self.cross_block_tables
+        block_ids = self.cross_block_tables[request_id].physical_block_ids
+        assert all(b is not None for b in block_ids)
+        return block_ids  # type: ignore
+
     def access_all_blocks_in_seq(self, seq: Sequence, now: float):
         # Update the last accessed time of all the blocks accessed
         # in this step.
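In the V2 can_allocate() change above, the encoder prompt now contributes its own block requirement on top of the decoder's. A back-of-the-envelope version of that arithmetic, with made-up token counts (the helper below only reproduces the rounding used by round_up_to_next_block in the test utilities):

block_size = 16


def num_required_blocks_for(num_tokens: int) -> int:
    # Ceil-divide a token count into KV-cache blocks.
    return (num_tokens + block_size - 1) // block_size


decoder_prompt_tokens = 40   # -> 3 blocks
encoder_prompt_tokens = 17   # -> 2 blocks
num_required_blocks = (num_required_blocks_for(decoder_prompt_tokens) +
                       num_required_blocks_for(encoder_prompt_tokens))
assert num_required_blocks == 5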
@@ -430,6 +430,8 @@ class SequenceGroup:
             for an embedding model.
         pooling_params: The pooling parameters used to generate the pooling
             for an embedding model.
+        encoder_seq: Optional, the single encoder sequence. Should be None
+                     unless you are working with an encoder/decoder model.
     """
 
     def __init__(
@@ -441,6 +443,7 @@ class SequenceGroup:
         lora_request: Optional[LoRARequest] = None,
         embeddings: Optional[List[float]] = None,
         pooling_params: Optional[PoolingParams] = None,
+        encoder_seq: Optional[Sequence] = None,
     ) -> None:
         self.request_id = request_id
         self.seqs_dict = {seq.seq_id: seq for seq in seqs}
@@ -455,6 +458,7 @@ class SequenceGroup:
         self.state = SequenceGroupState()
         self.embeddings = embeddings
         self.pooling_params = pooling_params
+        self.encoder_seq = encoder_seq
 
     @property
     def prompt(self) -> Optional[str]:
@@ -538,6 +542,12 @@ class SequenceGroup:
             seq for seq in self.seqs_dict.values() if seq.status == status
         ]
 
+    def is_encoder_decoder(self) -> bool:
+        return self.encoder_seq is not None
+
+    def get_encoder_seq(self) -> Optional[Sequence]:
+        return self.encoder_seq
+
     def get_unfinished_seqs(self) -> List[Sequence]:
         return [
             seq for seq in self.seqs_dict.values() if not seq.is_finished()
@@ -621,6 +631,15 @@ class SequenceGroupMetadata:
             used in prefix caching.
         state: Internal state tied to this sequence group.
         multi_modal_data: Multi modal data.
+        encoder_seq_data: Optional sequence data for encoder prompt
+                          (SequenceGroup.encoder_seq). Should be None
+                          unless you are working with an encoder/decoder
+                          model.
+        cross_block_table: Optional cross-attention block table associated
+                           with the encoder prompt
+                           (SequenceGroup.encoder_seq). Should be None
+                           unless you are working with an encoder/decoder
+                           model.
     """
 
     def __init__(
@@ -637,6 +656,8 @@ class SequenceGroupMetadata:
         computed_block_nums: Optional[List[int]] = None,
         state: Optional[SequenceGroupState] = None,
         multi_modal_data: Optional[MultiModalData] = None,
+        encoder_seq_data: Optional[SequenceData] = None,
+        cross_block_table: Optional[List[int]] = None,
     ) -> None:
         self.request_id = request_id
         self.is_prompt = is_prompt
@@ -648,6 +669,8 @@ class SequenceGroupMetadata:
         self.computed_block_nums = computed_block_nums
         self.multi_modal_data = multi_modal_data
         self.state = SequenceGroupState() if state is None else state
+        self.encoder_seq_data = encoder_seq_data
+        self.cross_block_table = cross_block_table
         self._token_chunk_size = token_chunk_size
         self.do_sample = do_sample