Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-21 23:45:02 +08:00)
Add Automatic Prefix Caching (#2762)

Co-authored-by: ElizaWszola <eliza@neuralmagic.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>

commit ce4f5a29fb
parent baee28c46c
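The diff below threads a new `enable_prefix_caching` option from the user-facing entry points (`LLM`, `EngineArgs`, and a `--enable-prefix-caching` flag) down to the block manager, and removes the old manual `prefix_pos` API. As a quick orientation, a minimal sketch of how the feature is switched on after this change (the model name is just the small test model used elsewhere in this diff):

```python
from vllm import LLM, SamplingParams

# Automatic prefix caching is opt-in via the new flag introduced here.
llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True)

sampling_params = SamplingParams(temperature=0.0, max_tokens=16)

# Requests that share a long common prefix now reuse its cached KV blocks
# automatically; no prefix_pos hint is needed anymore.
outputs = llm.generate(["Hello, my name is"], sampling_params)
```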
@@ -73,10 +73,10 @@ def run_vllm(
     enforce_eager: bool,
     kv_cache_dtype: str,
     device: str,
+    enable_prefix_caching: bool,
 ) -> float:
     from vllm import LLM, SamplingParams
-    llm = LLM(
-        model=model,
+    llm = LLM(model=model,
              tokenizer=tokenizer,
              quantization=quantization,
              tensor_parallel_size=tensor_parallel_size,
@@ -87,7 +87,7 @@ def run_vllm(
              enforce_eager=enforce_eager,
              kv_cache_dtype=kv_cache_dtype,
              device=device,
-    )
+              enable_prefix_caching=enable_prefix_caching)

     # Add the requests to the engine.
     for prompt, _, output_len in requests:
@@ -211,7 +211,8 @@ def main(args: argparse.Namespace):
                                 args.seed, args.n, args.use_beam_search,
                                 args.trust_remote_code, args.dtype,
                                 args.max_model_len, args.enforce_eager,
-                                args.kv_cache_dtype, args.device)
+                                args.kv_cache_dtype, args.device,
+                                args.enable_prefix_caching)
    elif args.backend == "hf":
        assert args.tensor_parallel_size == 1
        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -302,6 +303,7 @@ if __name__ == "__main__":
        default="cuda",
        choices=["cuda"],
        help='device type for vLLM execution, supporting CUDA only currently.')
+    parser.add_argument("--enable_prefix_caching", action='store_true')
    args = parser.parse_args()
    if args.tokenizer is None:
        args.tokenizer = args.model
@@ -81,6 +81,10 @@ Below, you can find an explanation of every engine argument for vLLM:

    Token block size for contiguous chunks of tokens.

+.. option:: --enable-prefix-caching
+
+    Enables automatic prefix caching
+
 .. option:: --seed <seed>

    Random seed for operations.
@@ -37,20 +37,13 @@ for output in outputs:

    print("-" * 80)

-# -1 since the last token can change when concatenating prompts.
-prefix_pos = len(llm.llm_engine.tokenizer.encode(prefix)) - 1
-
 # The llm.generate call will batch all prompts and send the batch at once if resources allow.
 # The prefix will only be cached after the first batch is processed, so we need to call generate once
 # to calculate the prefix and cache it.
-outputs = llm.generate(generating_prompts[0],
-                       sampling_params,
-                       prefix_pos=[prefix_pos])
+outputs = llm.generate(generating_prompts[0], sampling_params)

 # Subsequent batches can leverage the cached prefix
-outputs = llm.generate(generating_prompts,
-                       sampling_params,
-                       prefix_pos=[prefix_pos] * len(generating_prompts))
+outputs = llm.generate(generating_prompts, sampling_params)

 # Print the outputs. You should see the same outputs as before
 for output in outputs:
@@ -4,38 +4,73 @@ Run `pytest tests/prefix_caching/test_prefix_caching.py`.
 """
 import pytest

-from vllm import LLM, SamplingParams
-prefix = (
-    "You are an expert school principal, skilled in effectively managing "
-    "faculty and staff. Draft 10-15 questions for a potential first grade "
-    "Head Teacher for my K-12, all-girls', independent school that emphasizes "
-    "community, joyful discovery, and life-long learning. The candidate is "
-    "coming in for a first-round panel interview for a 8th grade Math "
-    "teaching role. They have 5 years of previous teaching experience "
-    "as an assistant teacher at a co-ed, public school with experience "
-    "in middle school math teaching. Based on these information, fulfill "
-    "the following paragraph: ")
+from vllm.core.block_manager import BlockAllocator
+from vllm.utils import Device


-@pytest.mark.parametrize("model", ["facebook/opt-125m"])
-@pytest.mark.parametrize("max_tokens", [16])
-def test_prefix_caching(
-    example_prompts,
-    model: str,
-    max_tokens: int,
+@pytest.mark.parametrize("block_size", [16])
+@pytest.mark.parametrize("num_blocks", [16])
+def test_block_allocator(
+    block_size: int,
+    num_blocks: int,
 ):
-    llm = LLM(model=model)
-    # -1 since the last token can change when concatenating prompts.
-    prefix_pos = len(llm.llm_engine.tokenizer.encode(prefix)) - 1
-    prompts = [prefix + prompt for prompt in example_prompts]
-    sampling_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
-    outputs_without_prefix = llm.generate(prompts, sampling_params)
-    outputs_with_prefix = llm.generate(prompts,
-                                       sampling_params,
-                                       prefix_pos=[prefix_pos] * len(prompts))
-    for output_without_prefix, output_with_prefix in zip(
-            outputs_without_prefix, outputs_with_prefix):
-        assert (output_without_prefix.outputs[0].token_ids ==
-                output_with_prefix.outputs[0].token_ids)
-    assert len(llm.llm_engine.scheduler.prefix_pool.prefixes) == 1
+    block_hash = 1
+    block_allocator = BlockAllocator(Device.CPU,
+                                     block_size,
+                                     num_blocks,
+                                     enable_caching=True)
+
+    # Allocate two PysicalTokenBlocks with the same hash and check that they are the same PhysicalTokenBlock
+    first_block = block_allocator.allocate(block_hash, 0)
+    second_block = block_allocator.allocate(block_hash, 0)
+    assert (first_block == second_block)
+    assert (second_block.ref_count == 2)
+
+    # Free the first_block and confirm that the ref_count is correctly decremented on the second block
+    block_allocator.free(first_block)
+    assert (second_block.ref_count == 1)
+
+    # Free the second block
+    block_allocator.free(second_block)
+
+    # Reallocate the first block and confirm that, even after the block had its ref_count go to 0, we still get the same block back
+    first_block = block_allocator.allocate(block_hash, 0)
+    assert (first_block == second_block)
+    assert (first_block.block_hash == block_hash)
+
+
+@pytest.mark.parametrize("num_blocks", [16])
+def test_eviction(num_blocks: int, ):
+    block_size = 16
+    block_allocator = BlockAllocator(Device.CPU,
+                                     block_size,
+                                     num_blocks,
+                                     enable_caching=True)
+    blocks = []
+
+    for i in range(num_blocks):
+        # use i as the block_hash
+        blocks.append(block_allocator.allocate(i, 0))
+
+    #Free all blocks
+    for block in blocks:
+        block_allocator.free(block)
+
+    # Allocate a new block and confirm that it's the first block freed. I.E The Least Recently Used block
+    new_block_hash = block_size
+    new_block = block_allocator.allocate(new_block_hash, 0)
+    assert (new_block == blocks[0])
+    assert (new_block.block_hash == new_block_hash)
+
+    # Reallocate the second in blocks to remove it from the free list
+    realloc_block_hash = 1
+    realloc_block = block_allocator.allocate(realloc_block_hash, 0)
+    assert (realloc_block == blocks[realloc_block_hash])
+    assert (realloc_block.block_hash == realloc_block_hash)
+
+    # Allocate a new block and confirm that it's not the realloc_block, since the realloc_block shouldn't be in the free list
+    new_block_hash = block_size + 1
+    new_block = block_allocator.allocate(new_block_hash, 0)
+    assert (realloc_block != new_block)
+    assert (new_block.block_hash == new_block_hash)
+    assert (new_block.block_number == 2)
tests/test_cache_block_hashing.py (new file, 76 lines)
@@ -0,0 +1,76 @@
+"""Test hashing of cache blocks.
+
+Run `pytest tests/test_cache_block_hashing.py`.
+"""
+import pytest
+
+from vllm.transformers_utils.tokenizer import TokenizerGroup
+from vllm.sequence import Sequence
+
+# Make two prefixes with different first blocks.
+prefix_start = [("You are an expert"), ("You are a")]
+prefix_common = (
+    " school principal, skilled in effectively managing "
+    "faculty and staff. Draft 10-15 questions for a potential first grade "
+    "Head Teacher for my K-12, all-girls', independent school that emphasizes "
+    "community, joyful discovery, and life-long learning. The candidate is "
+    "coming in for a first-round panel interview for a 8th grade Math "
+    "teaching role. They have 5 years of previous teaching experience "
+    "as an assistant teacher at a co-ed, public school with experience "
+    "in middle school math teaching. Based on this, fulfill "
+    "the following: ")
+prefixes = [start + prefix_common for start in prefix_start]
+
+# Sample prompts.
+sample_prompts = [
+    "Hello, my name is", "The president of the United States is",
+    "The capital of France is", "The future of AI is"
+]
+
+
+# Helper function.
+def flatten_2d(li):
+    return [lss for ls in li for lss in ls]
+
+
+@pytest.mark.parametrize("model", ["facebook/opt-125m"])
+@pytest.mark.parametrize("block_size", [16])
+@pytest.mark.parametrize("max_num_seqs", [256])
+def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int):
+
+    tokenizer = TokenizerGroup(
+        tokenizer_id="facebook/opt-125m",
+        enable_lora=False,
+        max_num_seqs=max_num_seqs,
+        max_input_length=None,
+    )
+
+    hashes = []
+
+    for prefix in prefixes:
+        hashes.append([])
+        prompts = [prefix + prompt for prompt in sample_prompts]
+        seq_id = 0
+        for prompt in prompts:
+            hashes[-1].append([])
+            prompt_token_ids = tokenizer.encode(prompt)
+            seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
+
+            num_blocks = len(prompt_token_ids) // block_size
+            for idx in range(num_blocks):
+                hashes[-1][-1].append(seq.hash_of_block(idx))
+
+            seq_id += 1
+
+    # Check that hashes made with two prefixes with different first blocks are
+    # different everywhere.
+    for hash0, hash1 in zip(flatten_2d(hashes[0]), flatten_2d(hashes[1])):
+        assert (hash0 != hash1)
+
+    # Check that hashes of different prompts made with the same prefix are the
+    # same until the hashes that contain the prompt.
+    for hash_pref in hashes:
+        same_hashes = [tuple(h[:-1]) for h in hash_pref]
+        different_hashes = [h[-1] for h in hash_pref]
+        assert (len(set(same_hashes)) == 1)
+        assert (len(set(different_hashes)) == len(different_hashes))
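The two assertions at the end of this new test only hold if a block's hash covers every token from the start of the sequence up to and including that block. A toy model of that property, for illustration only (this is an assumption about the intent, not the actual `Sequence.hash_of_block` implementation):

```python
from typing import List


def toy_hash_of_block(token_ids: List[int], block_size: int, idx: int) -> int:
    # Hash the whole prefix that ends at block `idx`, so two sequences can
    # only share a block hash when they agree on everything before it too.
    return hash(tuple(token_ids[:(idx + 1) * block_size]))


a = list(range(64))
b = [999] + list(range(1, 64))  # differs only in the very first token
assert toy_hash_of_block(a, 16, 1) != toy_hash_of_block(b, 16, 1)
# Identical first block, divergence later: block 0 hashes still agree.
assert toy_hash_of_block(a, 16, 0) == toy_hash_of_block(list(range(16)) + [7] * 48, 16, 0)
```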
@@ -5,6 +5,8 @@ from vllm.utils import Device

 _BLANK_TOKEN_ID = -1

+DEFAULT_LAST_ACCESSED_TIME = -1
+

 class LogicalTokenBlock:
     """A block that stores a contiguous chunk of tokens from left to right.
@@ -55,17 +57,27 @@ class PhysicalTokenBlock:
        device: Device,
        block_number: int,
        block_size: int,
+       block_hash: int,
+       num_hashed_tokens: int,
    ) -> None:
        self.device = device
        self.block_number = block_number
        self.block_size = block_size
+       self.block_hash = block_hash
+       self.num_hashed_tokens = num_hashed_tokens
+
        self.ref_count = 0
+       self.last_accessed = DEFAULT_LAST_ACCESSED_TIME
+
+       self.computed = False

    def __repr__(self) -> str:
        return (f'PhysicalTokenBlock(device={self.device}, '
                f'block_number={self.block_number}, '
-               f'ref_count={self.ref_count})')
+               f'num_hashed_tokens={self.num_hashed_tokens}, '
+               f'ref_count={self.ref_count}, '
+               f'last_accessed={self.last_accessed}, '
+               f'computed={self.computed})')


 # Mapping: logical block number -> physical block.
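For reference, a small sketch of constructing a block with the extended metadata above (the values are arbitrary placeholders):

```python
from vllm.block import PhysicalTokenBlock
from vllm.utils import Device

# The two new constructor arguments carry the content hash and the number of
# prefix tokens it covers; the bookkeeping fields start at their defaults
# (last_accessed == DEFAULT_LAST_ACCESSED_TIME, computed == False).
block = PhysicalTokenBlock(device=Device.CPU,
                           block_number=0,
                           block_size=16,
                           block_hash=12345,
                           num_hashed_tokens=16)
assert block.ref_count == 0 and not block.computed
```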
@@ -303,12 +303,14 @@ class CacheConfig:
        swap_space: int,
        cache_dtype: str,
        sliding_window: Optional[int] = None,
+       enable_prefix_caching: bool = False,
    ) -> None:
        self.block_size = block_size
        self.gpu_memory_utilization = gpu_memory_utilization
        self.swap_space_bytes = swap_space * _GB
        self.cache_dtype = cache_dtype
        self.sliding_window = sliding_window
+       self.enable_prefix_caching = enable_prefix_caching
        self._verify_args()
        self._verify_cache_dtype()

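A hedged sketch of the extended `CacheConfig` above, written with keyword arguments that mirror the positional call made from `EngineArgs` later in this diff (the `vllm.config` module path and the `"auto"` cache dtype are assumptions):

```python
from vllm.config import CacheConfig  # assumed module path

cache_config = CacheConfig(block_size=16,
                           gpu_memory_utilization=0.90,
                           swap_space=4,        # GiB
                           cache_dtype="auto",  # assumed value
                           sliding_window=None,
                           enable_prefix_caching=True)
assert cache_config.enable_prefix_caching
```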
@@ -1,10 +1,13 @@
 """A block manager that manages token blocks."""
 import enum
+from itertools import count
+from os.path import commonprefix
 from typing import Dict, List, Optional, Set, Tuple

 from vllm.block import BlockTable, PhysicalTokenBlock
 from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
 from vllm.utils import Device
+from vllm.core.evictor import Evictor, EvictionPolicy, make_evictor


 class BlockAllocator:
@@ -15,29 +18,68 @@ class BlockAllocator:
     the reference count becomes zero, the block is added back to the free list.
     """

-    def __init__(
-        self,
-        device: Device,
-        block_size: int,
-        num_blocks: int,
-    ) -> None:
+    def __init__(self,
+                 device: Device,
+                 block_size: int,
+                 num_blocks: int,
+                 eviction_policy: EvictionPolicy = EvictionPolicy.LRU,
+                 enable_caching: bool = False) -> None:
         self.device = device
         self.block_size = block_size
         self.num_blocks = num_blocks
+        self.enable_caching = enable_caching
+
+        self.current_num_blocks = 0
+        self.cached_blocks: Dict[int, PhysicalTokenBlock] = {}

-        # Initialize the free blocks.
-        self.free_blocks: BlockTable = []
-        for i in range(num_blocks):
-            block = PhysicalTokenBlock(device=device,
-                                       block_number=i,
-                                       block_size=block_size)
-            self.free_blocks.append(block)
+        # Switch over to FIFO eviction when caching is disabled
+        if not self.enable_caching:
+            eviction_policy = EvictionPolicy.FIFO
+        self.evictor: Evictor = make_evictor(eviction_policy)

-    def allocate(self) -> PhysicalTokenBlock:
-        if not self.free_blocks:
-            raise ValueError("Out of memory! No free blocks are available.")
-        block = self.free_blocks.pop()
-        block.ref_count = 1
+        self.default_hash_ctr = count()
+
+    def allocate_block(self, block_hash: int,
+                       num_hashed_tokens: int) -> PhysicalTokenBlock:
+        if self.current_num_blocks == self.num_blocks:
+            block = self.evictor.evict()
+            block.block_hash = block_hash
+            block.num_hashed_tokens = num_hashed_tokens
+            return block
+        block = PhysicalTokenBlock(device=self.device,
+                                   block_number=self.current_num_blocks,
+                                   block_size=self.block_size,
+                                   block_hash=block_hash,
+                                   num_hashed_tokens=num_hashed_tokens)
+        self.current_num_blocks += 1
+        return block
+
+    def allocate(self,
+                 block_hash: Optional[int] = None,
+                 num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
+        # If caching is disabled, just allocate a new block and return it
+        if not self.enable_caching:
+            block = self.allocate_block(next(self.default_hash_ctr),
+                                        num_hashed_tokens)
+            block.ref_count += 1
+            return block
+
+        if block_hash is None:
+            block_hash = next(self.default_hash_ctr)
+        if block_hash in self.evictor:
+            assert block_hash not in self.cached_blocks
+            block = self.evictor.remove(block_hash)
+            assert block.ref_count == 0
+            self.cached_blocks[block_hash] = block
+            block.ref_count += 1
+            assert block.block_hash == block_hash
+            return block
+        if block_hash not in self.cached_blocks:
+            self.cached_blocks[block_hash] = self.allocate_block(
+                block_hash, num_hashed_tokens)
+        block = self.cached_blocks[block_hash]
+        assert block.block_hash == block_hash
+        block.ref_count += 1
         return block

     def free(self, block: PhysicalTokenBlock) -> None:
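A quick sketch of the two allocation paths above, following the same pattern as the rewritten `test_block_allocator` earlier in this diff:

```python
from vllm.core.block_manager import BlockAllocator
from vllm.utils import Device

# With caching enabled, equal hashes resolve to the same physical block and
# the reference count is shared.
cached = BlockAllocator(Device.CPU, 16, 16, enable_caching=True)
a = cached.allocate(block_hash=1, num_hashed_tokens=16)
b = cached.allocate(block_hash=1, num_hashed_tokens=16)
assert a is b and a.ref_count == 2

# With caching disabled (the default), every call hands out a distinct block
# keyed by the internal counter, so nothing is ever shared.
plain = BlockAllocator(Device.CPU, 16, 16)
x = plain.allocate()
y = plain.allocate()
assert x is not y and x.ref_count == 1 and y.ref_count == 1
```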
@@ -45,10 +87,27 @@ class BlockAllocator:
             raise ValueError(f"Double free! {block} is already freed.")
         block.ref_count -= 1
         if block.ref_count == 0:
-            self.free_blocks.append(block)
+            assert block.block_hash not in self.evictor
+            self.evictor.add(block)
+
+            # If caching is enabled, remove the block from the cached_blocks
+            if self.enable_caching:
+                del self.cached_blocks[block.block_hash]

     def get_num_free_blocks(self) -> int:
-        return len(self.free_blocks)
+        return self.num_blocks - self.current_num_blocks + self.evictor.num_blocks
+
+    def contains_block(self, block_hash: int) -> bool:
+        return block_hash in self.cached_blocks or block_hash in self.evictor
+
+    def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
+        # If caching is enabled, update the hash of block and the cached_blocks dictionary.
+        if self.enable_caching:
+            assert not self.contains_block(block_hash)
+            old_hash = block.block_hash
+            block.block_hash = block_hash
+            del self.cached_blocks[old_hash]
+            self.cached_blocks[block_hash] = block


 class AllocStatus(enum.Enum):
@@ -75,6 +134,7 @@ class BlockSpaceManager:
         num_cpu_blocks: int,
         watermark: float = 0.01,
         sliding_window: Optional[int] = None,
+        enable_caching: bool = False,
     ) -> None:
         self.block_size = block_size
         self.num_total_gpu_blocks = num_gpu_blocks
@@ -89,11 +149,17 @@ class BlockSpaceManager:
         self.watermark = watermark
         assert watermark >= 0.0

+        self.enable_caching = enable_caching
+
         self.watermark_blocks = int(watermark * num_gpu_blocks)
-        self.gpu_allocator = BlockAllocator(Device.GPU, block_size,
-                                            num_gpu_blocks)
-        self.cpu_allocator = BlockAllocator(Device.CPU, block_size,
-                                            num_cpu_blocks)
+        self.gpu_allocator = BlockAllocator(Device.GPU,
+                                            block_size,
+                                            num_gpu_blocks,
+                                            enable_caching=enable_caching)
+        self.cpu_allocator = BlockAllocator(Device.CPU,
+                                            block_size,
+                                            num_cpu_blocks,
+                                            enable_caching=enable_caching)
         # Mapping: seq_id -> BlockTable.
         self.block_tables: Dict[int, BlockTable] = {}

@@ -103,9 +169,6 @@ class BlockSpaceManager:
         seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
         num_required_blocks = len(seq.logical_token_blocks)

-        if seq_group.prefix is not None and seq_group.prefix.allocated:
-            num_required_blocks -= seq_group.prefix.get_num_blocks()
-
         if self.block_sliding_window is not None:
             num_required_blocks = min(num_required_blocks,
                                       self.block_sliding_window)
@@ -129,36 +192,16 @@ class BlockSpaceManager:
         num_prompt_blocks = len(seq.logical_token_blocks)

         block_table: BlockTable = []
-        prefix_block_table: BlockTable = []
-        num_prefix_blocks = 0
-
-        prefix = seq_group.prefix
-        if prefix is not None and prefix.allocated:
-            # Prefix has already been allocated. Use the existing block table.
-            num_prompt_blocks -= prefix.get_num_blocks()
-            for block in prefix.block_table:
-                block.ref_count += seq_group.num_seqs()
-                block_table.append(block)
-
         for logical_idx in range(num_prompt_blocks):
             if (self.block_sliding_window is not None
                     and logical_idx >= self.block_sliding_window):
                 block = block_table[logical_idx % self.block_sliding_window]
             else:
-                block = self.gpu_allocator.allocate()
-                # Set the reference counts of the token blocks.
-                block.ref_count = seq_group.num_seqs()
+                block = self.gpu_allocator.allocate(
+                    seq.hash_of_block(logical_idx),
+                    seq.num_hashed_tokens_of_block(logical_idx))
             block_table.append(block)

-        if prefix is not None and not prefix.allocated:
-            # Allocate blocks for the prefix, we will compute the prefix's
-            # KV cache in this run.
-            num_prefix_blocks = prefix.get_num_blocks()
-            prefix_block_table = block_table[:num_prefix_blocks]
-            for block in prefix_block_table:
-                block.ref_count += 1
-            prefix.set_block_table(prefix_block_table)
-
         # Assign the block table for each sequence.
         for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
             self.block_tables[seq.seq_id] = block_table.copy()
@@ -170,12 +213,72 @@ class BlockSpaceManager:
         num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING)
         return num_seqs <= num_free_gpu_blocks

-    def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]:
+    def _promote_last_block(
+        self,
+        seq: Sequence,
+        last_block: PhysicalTokenBlock,
+    ) -> PhysicalTokenBlock:
+        # Compute a new hash for the block so that it can be shared by other Sequences
+        new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
+
+        # if new_hash is already in the cached table, then free last_block and return the cached version
+        if self.gpu_allocator.contains_block(new_hash):
+            self.gpu_allocator.free(last_block)
+            return self.gpu_allocator.allocate(new_hash)
+        else:
+            self.gpu_allocator.update_hash(new_hash, last_block)
+            return last_block
+
+    def _is_last_block_full(
+        self,
+        seq: Sequence,
+    ) -> bool:
+        token_ids_len = len(seq.data.get_token_ids())
+        return token_ids_len > 0 and token_ids_len % seq.block_size == 0
+
+    def _is_last_block(
+        self,
+        seq: Sequence,
+        index: int,
+    ) -> bool:
+        return index == len(seq.logical_token_blocks) - 1
+
+    def _maybe_promote_last_block(
+        self,
+        seq: Sequence,
+        last_block: PhysicalTokenBlock,
+    ) -> PhysicalTokenBlock:
+        if self._is_last_block_full(seq):
+            return self._promote_last_block(seq, last_block)
+        else:
+            return last_block
+
+    def _allocate_last_physical_block(
+        self,
+        seq: Sequence,
+    ) -> PhysicalTokenBlock:
+        block_hash: Optional[int] = None
+        if (self._is_last_block_full(seq)):
+            block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
+        num_hashed_tokens = seq.num_hashed_tokens_of_block(
+            len(seq.logical_token_blocks) - 1)
+        new_block = self.gpu_allocator.allocate(block_hash, num_hashed_tokens)
+        if block_hash is None:
+            assert new_block.ref_count == 1
+        return new_block
+
+    def append_slot(
+        self,
+        seq: Sequence,
+    ) -> Optional[Tuple[int, int]]:
         """Allocate a physical slot for a new token."""
         logical_blocks = seq.logical_token_blocks
         block_table = self.block_tables[seq.seq_id]
+        # If we need to allocate a new physical block
         if len(block_table) < len(logical_blocks):
+            # Currently this code only supports adding one physical block
+            assert len(block_table) == len(logical_blocks) - 1
+
             if (self.block_sliding_window
                     and len(block_table) >= self.block_sliding_window):
                 # reuse a block
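The helpers above all hinge on one rule: the last block of a sequence is only promoted to a shareable, content-hashed block once it is completely full. A standalone illustration of that predicate (plain integers instead of a `Sequence`, mirroring `_is_last_block_full`):

```python
def is_last_block_full(num_tokens: int, block_size: int) -> bool:
    # Full exactly when the token count is a non-zero multiple of block_size.
    return num_tokens > 0 and num_tokens % block_size == 0


assert is_last_block_full(32, 16)      # two full blocks -> promotable
assert not is_last_block_full(33, 16)  # partial last block -> keeps a private hash
assert not is_last_block_full(0, 16)   # empty sequence -> nothing to promote
```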
@@ -184,8 +287,8 @@ class BlockSpaceManager:
             else:
                 # The sequence has a new logical block.
                 # Allocate a new physical block.
-                block = self.gpu_allocator.allocate()
-                block_table.append(block)
+                new_block = self._allocate_last_physical_block(seq)
+                block_table.append(new_block)
             return None

         # We want to append the token to the last physical block.
@@ -193,11 +296,15 @@ class BlockSpaceManager:
         assert last_block.device == Device.GPU
         if last_block.ref_count == 1:
             # Not shared with other sequences. Appendable.
+            # If the last block is now complete, promote it to a full block so that it can be shared
+            new_block = self._maybe_promote_last_block(seq, last_block)
+            block_table[-1] = new_block
             return None
         else:
             # The last block is shared with other sequences.
             # Copy on Write: Allocate a new block and copy the tokens.
-            new_block = self.gpu_allocator.allocate()
+            new_block = self._allocate_last_physical_block(seq)
+
             block_table[-1] = new_block
             self.gpu_allocator.free(last_block)
             return last_block.block_number, new_block.block_number
@@ -233,25 +340,18 @@ class BlockSpaceManager:

     def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
         # CPU block -> GPU block.
-        if seq_group.prefix is not None:
-            # make sure to swap in the prefix first
-            assert seq_group.prefix.allocated and seq_group.prefix.computed
-
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
             new_block_table: BlockTable = []
             block_table = self.block_tables[seq.seq_id]
-            if seq_group.prefix is not None:
-                for block in seq_group.prefix.block_table:
-                    new_block_table.append(block)
-                    block.ref_count += 1

             for cpu_block in block_table:
                 if cpu_block in mapping:
                     gpu_block = mapping[cpu_block]
                     gpu_block.ref_count += 1
                 else:
-                    gpu_block = self.gpu_allocator.allocate()
+                    gpu_block = self.gpu_allocator.allocate(
+                        cpu_block.block_hash, cpu_block.num_hashed_tokens)
                     mapping[cpu_block] = gpu_block
                 new_block_table.append(gpu_block)
                 # Free the CPU block swapped in to GPU.
@@ -276,17 +376,12 @@ class BlockSpaceManager:
             block_table = self.block_tables[seq.seq_id]

             for gpu_block in block_table:
-                if (seq_group.prefix is not None
-                        and gpu_block in seq_group.prefix.block_table):
-                    # NOTE: We do not swap out the prefix blocks for now.
-                    self.gpu_allocator.free(gpu_block)
-                    continue
-
                 if gpu_block in mapping:
                     cpu_block = mapping[gpu_block]
                     cpu_block.ref_count += 1
                 else:
-                    cpu_block = self.cpu_allocator.allocate()
+                    cpu_block = self.cpu_allocator.allocate(
+                        gpu_block.block_hash, gpu_block.num_hashed_tokens)
                     mapping[gpu_block] = cpu_block
                 new_block_table.append(cpu_block)
                 # Free the GPU block swapped out to CPU.
@@ -328,3 +423,49 @@ class BlockSpaceManager:

     def get_num_free_cpu_blocks(self) -> int:
         return self.cpu_allocator.get_num_free_blocks()
+
+    def access_all_blocks_in_seq(
+        self,
+        seq: Sequence,
+        access_time: float,
+    ) -> None:
+        block_table = self.block_tables[seq.seq_id]
+        for block in block_table:
+            block.last_accessed = access_time
+
+    def compute_last_full_block_in_seq(self, seq: Sequence):
+        if seq.seq_id not in self.block_tables:
+            return
+        max_full_block = seq.get_len() // seq.block_size - 1
+        block_table = self.block_tables[seq.seq_id]
+        if max_full_block == -1:
+            return
+        block_table[max_full_block].computed = True
+
+    def get_all_block_ids_till_computed(self, seq: Sequence) -> List[int]:
+        if seq.seq_id not in self.block_tables:
+            return []
+        block_table = self.block_tables[seq.seq_id]
+        for block_idx in reversed(range(len(block_table))):
+            if block_table[block_idx].computed:
+                return [b.block_number for b in block_table[:block_idx + 1]]
+        return []
+
+    # Can return non-empty result only with prefix caching enabled.
+    def get_common_computed_block_ids(self,
+                                      seq_group: SequenceGroup) -> List[int]:
+        if not self.enable_caching:
+            return []
+
+        ids_list = [
+            self.get_all_block_ids_till_computed(seq)
+            for seq in iter(seq_group.seqs_dict.values())
+        ]
+        return commonprefix([ids for ids in ids_list if ids != []])
+
+    # We only mark the last full block because with prefix caching,
+    # all blocks until the marked one are guaranteed to be computed.
+    def mark_blocks_as_computed(self, seq_group: SequenceGroup):
+        if self.enable_caching:
+            for seq in seq_group.seqs_dict.values():
+                self.compute_last_full_block_in_seq(seq)
vllm/core/evictor.py (new file, 161 lines)
@@ -0,0 +1,161 @@
+import enum
+from typing import Dict, List, Optional
+from abc import ABC, abstractmethod, abstractproperty
+
+from vllm.block import PhysicalTokenBlock
+
+
+class EvictionPolicy(enum.Enum):
+    """Enum for eviction policy used by make_evictor to instantiate the correct
+    Evictor subclass.
+    """
+    LRU = enum.auto()
+    FIFO = enum.auto()
+
+
+class Evictor(ABC):
+    """The Evictor subclasses should be used by the BlockAllocator class to
+    handle eviction of freed PhysicalTokenBlocks.
+    """
+
+    @abstractmethod
+    def __init__(self):
+        pass
+
+    @abstractmethod
+    def __contains__(self, block_hash: int) -> bool:
+        pass
+
+    @abstractmethod
+    def evict(self) -> PhysicalTokenBlock:
+        """Runs the eviction algorithm and returns the evicted block"""
+        pass
+
+    @abstractmethod
+    def add(self, block: PhysicalTokenBlock):
+        """Adds block to the evictor, making it a candidate for eviction"""
+        pass
+
+    @abstractmethod
+    def remove(self, block_hash: int) -> PhysicalTokenBlock:
+        """Simply removes the block with the hash value block_hash from the
+        evictor. Caller is responsible for making sure that block_hash is contained
+        in the evictor before calling remove. Should be used to "bring back" blocks
+        that have been freed but not evicted yet.
+        """
+        pass
+
+    @abstractproperty
+    def num_blocks(self) -> int:
+        pass
+
+
+class LRUEvictor(Evictor):
+    """Evicts in a least-recently-used order using the last_accessed timestamp
+    that's recorded in the PhysicalTokenBlock. If there are multiple blocks with
+    the same last_accessed time, then the one with the largest num_hashed_tokens
+    will be evicted. If two blocks each have the lowest last_accessed time and
+    highest num_hashed_tokens value, then one will be chose arbitrarily
+    """

+    def __init__(self):
+        self.free_table: Dict[int, PhysicalTokenBlock] = {}
+
+    def __contains__(self, block_hash: int) -> bool:
+        return block_hash in self.free_table
+
+    # TODO: The performance of this evict function can be optimized further.
+    def evict(self) -> PhysicalTokenBlock:
+        free_blocks: List[PhysicalTokenBlock] = list(self.free_table.values())
+        if len(free_blocks) == 0:
+            raise ValueError("No usable cache memory left")
+
+        # Find lowest timestamp
+        lowest_timestamp = free_blocks[0].last_accessed
+        for block in free_blocks:
+            if block.last_accessed < lowest_timestamp:
+                lowest_timestamp = block.last_accessed
+
+        # Find all blocks with the lowest timestamp
+        least_recent: List[PhysicalTokenBlock] = []
+        for block in free_blocks:
+            if block.last_accessed == lowest_timestamp:
+                least_recent.append(block)
+
+        # Find highest prefix count per block
+        highest_num_hashed_tokens = 0
+        for block in least_recent:
+            if block.num_hashed_tokens > highest_num_hashed_tokens:
+                highest_num_hashed_tokens = block.num_hashed_tokens
+
+        evicted_block: Optional[PhysicalTokenBlock] = None
+
+        # Find the first block with the lowest timestamp
+        for block in least_recent:
+            if block.num_hashed_tokens == highest_num_hashed_tokens:
+                evicted_block = block
+                break
+
+        assert evicted_block is not None
+
+        del self.free_table[evicted_block.block_hash]
+
+        evicted_block.computed = False
+        return evicted_block
+
+    def add(self, block: PhysicalTokenBlock):
+        self.free_table[block.block_hash] = block
+
+    def remove(self, block_hash: int) -> PhysicalTokenBlock:
+        if block_hash not in self.free_table:
+            raise ValueError(
+                "Attempting to remove block that's not in the evictor")
+        block: PhysicalTokenBlock = self.free_table[block_hash]
+        del self.free_table[block_hash]
+        return block
+
+    @property
+    def num_blocks(self) -> int:
+        return len(self.free_table)
+
+
+class RandomEvictor(Evictor):
+    """Evicts in a first-in-first-out order"""
+
+    def __init__(self):
+        self.free_table: Dict[int, PhysicalTokenBlock] = {}
+
+    def __contains__(self, block_hash: int) -> bool:
+        return block_hash in self.free_table
+
+    def evict(self) -> PhysicalTokenBlock:
+        if len(self.free_table) == 0:
+            raise ValueError("No usable cache memory left")
+        evicted_block = next(iter(self.free_table.values()))
+        evicted_block.computed = False
+        del self.free_table[evicted_block.block_hash]
+        return evicted_block
+
+    def add(self, block: PhysicalTokenBlock):
+        self.free_table[block.block_hash] = block
+
+    def remove(self, block_hash: int) -> PhysicalTokenBlock:
+        if block_hash not in self.free_table:
+            raise ValueError(
+                "Attempting to remove block that's not in the evictor")
+        block: PhysicalTokenBlock = self.free_table[block_hash]
+        del self.free_table[block_hash]
+        return block
+
+    @property
+    def num_blocks(self) -> int:
+        return len(self.free_table)
+
+
+def make_evictor(eviction_policy: EvictionPolicy) -> Evictor:
+    if eviction_policy == EvictionPolicy.LRU:
+        return LRUEvictor()
+    elif eviction_policy == EvictionPolicy.FIFO:
+        return RandomEvictor()
+    else:
+        raise ValueError(f"Unknown cache eviction policy: {eviction_policy}")
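A short usage sketch of the new evictor module as exercised by the allocator above: freed blocks are added, `evict()` returns the least-recently-used one (ties broken toward the largest `num_hashed_tokens`), and `remove()` pulls a block back before it is evicted. The timestamps below are arbitrary.

```python
from vllm.block import PhysicalTokenBlock
from vllm.core.evictor import EvictionPolicy, make_evictor
from vllm.utils import Device

evictor = make_evictor(EvictionPolicy.LRU)

# Two freed blocks with distinct hashes and access times.
old = PhysicalTokenBlock(Device.CPU, 0, 16, block_hash=1, num_hashed_tokens=16)
new = PhysicalTokenBlock(Device.CPU, 1, 16, block_hash=2, num_hashed_tokens=16)
old.last_accessed = 1.0
new.last_accessed = 2.0

evictor.add(old)
evictor.add(new)

assert 1 in evictor and 2 in evictor
assert evictor.evict() is old    # lowest last_accessed is evicted first
assert evictor.remove(2) is new  # bring the remaining block back
assert evictor.num_blocks == 0
```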
@@ -10,7 +10,6 @@ from vllm.lora.request import LoRARequest
 from vllm.logger import init_logger
 from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
                            SequenceGroupMetadata, SequenceStatus)
-from vllm.prefix import PrefixPool

 logger = init_logger(__name__)

@@ -95,10 +94,8 @@ class Scheduler:
             block_size=self.cache_config.block_size,
             num_gpu_blocks=self.cache_config.num_gpu_blocks,
             num_cpu_blocks=self.cache_config.num_cpu_blocks,
-            sliding_window=self.cache_config.sliding_window)
-        # Create the prefix pool to cache the prefixes.
-        self.prefix_pool = PrefixPool(self.cache_config.block_size)
+            sliding_window=self.cache_config.sliding_window,
+            enable_caching=self.cache_config.enable_prefix_caching)

         # Sequence groups in the WAITING state.
         self.waiting: Deque[SequenceGroup] = deque()
@@ -374,10 +371,12 @@ class Scheduler:

             seq_data: Dict[int, SequenceData] = {}
             block_tables: Dict[int, List[int]] = {}
+
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                 seq_id = seq.seq_id
                 seq_data[seq_id] = seq.data
                 block_tables[seq_id] = self.block_manager.get_block_table(seq)
+                self.block_manager.access_all_blocks_in_seq(seq, now)

             seq_group_metadata = SequenceGroupMetadata(
                 request_id=seq_group.request_id,
@@ -386,7 +385,8 @@ class Scheduler:
                 sampling_params=seq_group.sampling_params,
                 block_tables=block_tables,
                 lora_request=seq_group.lora_request,
-                prefix=seq_group.prefix,
+                computed_block_nums=self.block_manager.
+                get_common_computed_block_ids(seq_group),
                 state=seq_group.state,
             )
             seq_group_metadata_list.append(seq_group_metadata)
@@ -496,3 +496,6 @@ class Scheduler:
         blocks_to_swap_out.update(mapping)
         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
             seq.status = SequenceStatus.SWAPPED
+
+    def mark_blocks_as_computed(self, seq_group: SequenceGroup):
+        self.block_manager.mark_blocks_as_computed(seq_group)
@@ -25,6 +25,7 @@ class EngineArgs:
     tensor_parallel_size: int = 1
     max_parallel_loading_workers: Optional[int] = None
     block_size: int = 16
+    enable_prefix_caching: bool = False
     swap_space: int = 4  # GiB
     gpu_memory_utilization: float = 0.90
     max_num_batched_tokens: Optional[int] = None
@@ -173,6 +174,11 @@ class EngineArgs:
                             default=EngineArgs.block_size,
                             choices=[8, 16, 32, 128],
                             help='token block size')
+
+        parser.add_argument('--enable-prefix-caching',
+                            action='store_true',
+                            help='Enables automatic prefix caching')
+
         parser.add_argument('--seed',
                             type=int,
                             default=EngineArgs.seed,
@@ -293,7 +299,8 @@ class EngineArgs:
        cache_config = CacheConfig(self.block_size,
                                   self.gpu_memory_utilization,
                                   self.swap_space, self.kv_cache_dtype,
-                                  model_config.get_sliding_window())
+                                  model_config.get_sliding_window(),
+                                  self.enable_prefix_caching)
        parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                         self.tensor_parallel_size,
                                         self.worker_use_ray,
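At the engine-args level the new dataclass field and the `--enable-prefix-caching` flag feed straight into `CacheConfig`. A hedged sketch of the Python-side usage (assuming the usual `vllm.engine.arg_utils` location; the model name is illustrative):

```python
from vllm.engine.arg_utils import EngineArgs  # assumed module path

engine_args = EngineArgs(model="facebook/opt-125m", enable_prefix_caching=True)
assert engine_args.enable_prefix_caching

# The field defaults to off, matching the `store_true` CLI flag added above.
assert EngineArgs(model="facebook/opt-125m").enable_prefix_caching is False
```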
@@ -225,7 +225,6 @@ class _AsyncLLMEngine(LLMEngine):
         prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
-        prefix_pos: Optional[int] = None,
     ) -> None:
         if lora_request is not None and not self.lora_config:
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "
@@ -245,7 +244,6 @@ class _AsyncLLMEngine(LLMEngine):
             sampling_params=sampling_params,
             arrival_time=arrival_time,
             lora_request=lora_request,
-            prefix_pos=prefix_pos,
         )

     async def _run_workers_async(
@@ -422,7 +420,6 @@ class AsyncLLMEngine:
         prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
-        prefix_pos: Optional[int] = None,
     ) -> AsyncStream:
         if self.log_requests:
             shortened_prompt = prompt
@@ -435,7 +432,6 @@ class AsyncLLMEngine:
                                                        max_log_len]
             logger.info(f"Received request {request_id}: "
                         f"prompt: {shortened_prompt!r}, "
-                        f"prefix_pos: {prefix_pos},"
                         f"sampling_params: {sampling_params}, "
                         f"prompt_token_ids: {shortened_token_ids}, "
                         f"lora_request: {lora_request}.")
@@ -472,8 +468,7 @@ class AsyncLLMEngine:
             sampling_params=sampling_params,
             prompt_token_ids=prompt_token_ids,
             arrival_time=arrival_time,
-            lora_request=lora_request,
-            prefix_pos=prefix_pos)
+            lora_request=lora_request)

         return stream

@@ -484,7 +479,6 @@ class AsyncLLMEngine:
         request_id: str,
         prompt_token_ids: Optional[List[int]] = None,
         lora_request: Optional[LoRARequest] = None,
-        prefix_pos: Optional[int] = None,
     ) -> AsyncIterator[RequestOutput]:
         """Generate outputs for a request.

@@ -500,11 +494,6 @@ class AsyncLLMEngine:
             prompt_token_ids: The token IDs of the prompt. If None, we
                 use the tokenizer to convert the prompts to token IDs.
             lora_request: LoRA request to use for generation, if any.
-            prefix_pos: If not None, we use the given position as the prefix
-                position for each prompt. We will cache the prefix's KV
-                cache and reuse it for the next request with the same prefix.
-                This is an experimental feature, and may be replaced with
-                automatic prefix caching in the future.

         Yields:
             The output `RequestOutput` objects from the LLMEngine for the
@@ -565,7 +554,6 @@ class AsyncLLMEngine:
                 prompt_token_ids=prompt_token_ids,
                 arrival_time=arrival_time,
                 lora_request=lora_request,
-                prefix_pos=prefix_pos,
             )

             async for request_output in stream:
@@ -415,7 +415,6 @@ class LLMEngine:
         prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
-        prefix_pos: Optional[int] = None,
     ) -> None:
         """Add a request to the engine's request pool.

@@ -432,11 +431,6 @@ class LLMEngine:
                 use the tokenizer to convert the prompts to token IDs.
             arrival_time: The arrival time of the request. If None, we use
                 the current monotonic time.
-            prefix_pos: If not None, we use the given position as the prefix
-                position for each prompt. We will cache the prefix's KV
-                cache and reuse it for the next request with the same prefix.
-                This is an experimental feature, and may be replaced with
-                automatic prefix caching in the future.

         Details:
             - Set arrival_time to the current time if it is None.
@@ -479,18 +473,13 @@ class LLMEngine:
         seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
                        lora_request)

-        # Check whether the input specifies prefix
-        prefix = self.scheduler.prefix_pool.add_or_get_prefix(
-            prompt_token_ids[:prefix_pos], lora_request.lora_int_id
-            if lora_request else 0) if prefix_pos is not None else None
-
         # Defensive copy of SamplingParams, which are used by the sampler,
         # this doesn't deep-copy LogitsProcessor objects
         sampling_params = sampling_params.clone()

         # Create the sequence group.
         seq_group = SequenceGroup(request_id, [seq], sampling_params,
-                                  arrival_time, lora_request, prefix)
+                                  arrival_time, lora_request)

         # Add the sequence group to the scheduler.
         self.scheduler.add_seq_group(seq_group)
@@ -752,6 +741,13 @@ class LLMEngine:
         now = time.time()
         # Update the scheduled sequence groups with the model outputs.
         scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
+
+        # If prefix caching is enabled, mark all blocks in the sequence groups
+        # as completed so that future requests don't attempt to recompute them
+        if self.cache_config.enable_prefix_caching:
+            for seq_group in scheduled_seq_groups:
+                self.scheduler.mark_blocks_as_computed(seq_group)
+
         for seq_group, outputs in zip(scheduled_seq_groups, output):
             self._process_sequence_group_outputs(seq_group, outputs)

@@ -768,12 +764,6 @@ class LLMEngine:
             request_output = RequestOutput.from_seq_group(seq_group)
             request_outputs.append(request_output)

-        # Update prefix state, now all the uncomputed prefixes are computed.
-        for seq_group in scheduled_seq_groups:
-            if (seq_group.prefix is not None and seq_group.prefix.allocated
-                    and not seq_group.prefix.computed):
-                seq_group.prefix.computed = True
-
         # Log stats.
         if self.log_stats:
             self.stat_logger.log(self._get_stats(scheduler_outputs))
@ -39,15 +39,11 @@ async def generate(request: Request) -> Response:
     """
     request_dict = await request.json()
     prompt = request_dict.pop("prompt")
-    prefix_pos = request_dict.pop("prefix_pos", None)
    stream = request_dict.pop("stream", False)
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()
 
-    results_generator = engine.generate(prompt,
-                                        sampling_params,
-                                        request_id,
-                                        prefix_pos=prefix_pos)
+    results_generator = engine.generate(prompt, sampling_params, request_id)
 
     # Streaming case
     async def stream_results() -> AsyncGenerator[bytes, None]:
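With `prefix_pos` no longer read from the request body, clients simply stop sending it; prefix reuse is handled inside the engine. A minimal client sketch follows, assuming the demo server is running locally on its default port; the field values are made up, and any extra fields are still forwarded to `SamplingParams` as before.

# Minimal client sketch for the /generate endpoint after this change.
# The URL and port are assumptions (a locally running demo server).
import json
import urllib.request

payload = {
    "prompt": "The capital of France is",
    "stream": False,          # handled by the endpoint itself
    "max_tokens": 16,         # forwarded to SamplingParams(**request_dict)
    "temperature": 0.0,
    # "prefix_pos": ...       # removed: prefix reuse is now automatic
}
req = urllib.request.Request(
    "http://localhost:8000/generate",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read()))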
@ -124,7 +124,6 @@ class LLM:
         prompts: Optional[Union[str, List[str]]] = None,
         sampling_params: Optional[SamplingParams] = None,
         prompt_token_ids: Optional[List[List[int]]] = None,
-        prefix_pos: Optional[Union[int, List[int]]] = None,
         use_tqdm: bool = True,
         lora_request: Optional[LoRARequest] = None,
     ) -> List[RequestOutput]:
@ -140,11 +139,6 @@ class LLM:
                 None, we use the default sampling parameters.
             prompt_token_ids: A list of token IDs for the prompts. If None, we
                 use the tokenizer to convert the prompts to token IDs.
-            prefix_pos: If not None, we use the given position as the prefix
-                position for each prompt. We will cache the prefix's KV
-                cache and reuse it for the next request with the same prefix.
-                This is an experimental feature, and may be replaced with
-                automatic prefix caching in the future.
             use_tqdm: Whether to use tqdm to display the progress bar.
             lora_request: LoRA request to use for generation, if any.
 
@ -171,14 +165,12 @@ class LLM:
                 prompt_token_ids)
         for i in range(num_requests):
             prompt = prompts[i] if prompts is not None else None
-            prefix_pos_i = prefix_pos[i] if prefix_pos is not None else None
             token_ids = None if prompt_token_ids is None else prompt_token_ids[
                 i]
             self._add_request(prompt,
                               sampling_params,
                               token_ids,
-                              lora_request=lora_request,
-                              prefix_pos=prefix_pos_i)
+                              lora_request=lora_request)
         return self._run_engine(use_tqdm)
 
     def _add_request(
@ -187,15 +179,13 @@ class LLM:
         sampling_params: SamplingParams,
         prompt_token_ids: Optional[List[int]],
         lora_request: Optional[LoRARequest] = None,
-        prefix_pos: Optional[int] = None,
     ) -> None:
         request_id = str(next(self.request_counter))
         self.llm_engine.add_request(request_id,
                                     prompt,
                                     sampling_params,
                                     prompt_token_ids,
-                                    lora_request=lora_request,
-                                    prefix_pos=prefix_pos)
+                                    lora_request=lora_request)
 
     def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
         # Initialize tqdm.
@ -1,87 +0,0 @@
-from typing import Dict, List, Sequence, Tuple, Optional
-
-from vllm.block import BlockTable
-
-
-class Prefix:
-    """Data and states associated with a prefix of prompt tokens for multiple
-    sequence groups.
-
-    NOTE: This feature is experimental and may be replaced with automatic
-    prefix caching in the future.
-
-    Args:
-        token_ids: The token ids of the prefix.
-        block_size: The block size of the executed model.
-    """
-
-    def __init__(
-        self,
-        token_ids: Sequence[int],
-        block_size: int,
-    ) -> None:
-        self.token_ids = tuple(token_ids)
-        self.block_size = block_size
-        self.length = len(token_ids)
-        self.hash = hash(token_ids)
-        assert self.length % block_size == 0
-        self.block_table: Optional[BlockTable] = None
-        self.computed = False
-
-    @property
-    def allocated(self) -> bool:
-        return self.block_table is not None
-
-    def get_num_blocks(self) -> int:
-        return self.length // self.block_size
-
-    def get_block_numbers(self) -> List[int]:
-        return [block.block_number for block in self.block_table]
-
-    def get_length(self) -> int:
-        return self.length
-
-    def __hash__(self) -> int:
-        return self.hash
-
-    def set_block_table(self, block_table: BlockTable) -> None:
-        self.block_table = block_table.copy()
-
-
-class PrefixPool:
-    """Manages all the prompt prefixes.
-
-    NOTE: This feature is experimental and may be replaced with automatic
-    prefix caching in the future.
-
-    Args:
-        block_size: The block size of the executed model.
-
-    Attributes:
-        prefixes: A list of all the prefixes.
-        block_size: The block size of the executed model.
-    """
-
-    def __init__(
-        self,
-        block_size: int,
-    ) -> None:
-        # TODO(zhuohan): Add a capacity limit to the prefix pool.
-        self.prefixes: Dict[int, Prefix] = {}
-        self.block_size = block_size
-
-    def _truncate_token_ids(self, token_ids: Sequence[int]) -> Tuple[int]:
-        new_length = len(token_ids) // self.block_size * self.block_size
-        return tuple(token_ids[:new_length])
-
-    def add_or_get_prefix(self, token_ids: Sequence[int],
-                          lora_int_id: int) -> Optional[Prefix]:
-        token_ids = self._truncate_token_ids(token_ids)
-        if len(token_ids) == 0:
-            # Prefix is empty.
-            return None
-        prefix = Prefix(token_ids, self.block_size)
-        prefix_hash = hash((prefix, lora_int_id))
-        if prefix_hash not in self.prefixes:
-            self.prefixes[prefix_hash] = prefix
-        return self.prefixes[prefix_hash]
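For contrast with the per-block hashing that replaces it (see the `Sequence` changes below), the deleted pool keyed one entry on the hash of the whole block-aligned prefix plus the LoRA id, so only a prefix the caller named explicitly could ever be reused. A standalone toy comparison, not vLLM code:

# Toy contrast between the removed PrefixPool keying and the new per-block
# keying (illustrative only; not the vLLM implementation).
BLOCK_SIZE = 4
tokens = list(range(10))          # a 10-token prompt
lora_int_id = 0

# Old scheme: one key for the whole block-aligned prefix the caller specified.
aligned = tuple(tokens[:len(tokens) // BLOCK_SIZE * BLOCK_SIZE])
old_key = hash((aligned, lora_int_id))

# New scheme: one key per block, each covering the prefix up to that block,
# so any shared leading blocks match without the caller hinting anything.
new_keys = [hash(tuple(tokens[:(i + 1) * BLOCK_SIZE]))
            for i in range(len(tokens) // BLOCK_SIZE)]

print(old_key)    # a single pool entry
print(new_keys)   # one entry per full block: [hash(t[:4]), hash(t[:8])]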
@ -5,7 +5,6 @@ from dataclasses import dataclass
 from typing import Dict, List, Optional, Union
 
 from vllm.block import LogicalTokenBlock
-from vllm.prefix import Prefix
 from vllm.sampling_params import SamplingParams
 from vllm.lora.request import LoRARequest
 
@ -161,6 +160,16 @@ class Sequence:
     def lora_int_id(self) -> int:
         return self.lora_request.lora_int_id if self.lora_request else 0
 
+    # TODO The current hashing function is O(L^2). We should optimize this in
+    # the future.
+    def hash_of_block(self, logical_idx: int) -> int:
+        # Compute the number of tokens in the sequence
+        num_tokens = self.num_hashed_tokens_of_block(logical_idx)
+        return hash(tuple(self.data.get_token_ids()[0:num_tokens]))
+
+    def num_hashed_tokens_of_block(self, logical_idx: int):
+        return logical_idx * self.block_size + self.block_size
+
     def _append_logical_block(self) -> None:
         block = LogicalTokenBlock(
             block_number=len(self.logical_token_blocks),
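A quick worked example of the two helpers added above: the hash for logical block `i` covers the entire prefix up to the end of that block, so two sequences that agree on their leading blocks produce equal hashes there and diverge afterwards. The toy class below only mimics these two methods; it is not the real `Sequence`.

# Standalone illustration of hash_of_block / num_hashed_tokens_of_block
# (a toy stand-in for Sequence, not the real class).
class ToySequence:
    def __init__(self, token_ids, block_size):
        self.token_ids = token_ids
        self.block_size = block_size

    def num_hashed_tokens_of_block(self, logical_idx: int) -> int:
        # Block i covers tokens [0, (i + 1) * block_size); the hash is over
        # the whole prefix, not just that block, so equal hashes imply equal
        # prefixes (up to hash collisions).
        return logical_idx * self.block_size + self.block_size

    def hash_of_block(self, logical_idx: int) -> int:
        num_tokens = self.num_hashed_tokens_of_block(logical_idx)
        return hash(tuple(self.token_ids[:num_tokens]))


a = ToySequence([7, 7, 7, 7, 1, 2, 3, 4], block_size=4)
b = ToySequence([7, 7, 7, 7, 9, 9, 9, 9], block_size=4)
assert a.hash_of_block(0) == b.hash_of_block(0)   # shared first block
assert a.hash_of_block(1) != b.hash_of_block(1)   # sequences diverge in block 1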
@ -265,7 +274,6 @@ class SequenceGroup:
         sampling_params: The sampling parameters used to generate the outputs.
         arrival_time: The arrival time of the request.
         lora_request: LoRA request.
-        prefix: The prefix of the prompt of the sequence group.
     """
 
     def __init__(
@ -275,7 +283,6 @@ class SequenceGroup:
         sampling_params: SamplingParams,
         arrival_time: float,
         lora_request: Optional[LoRARequest] = None,
-        prefix: Optional[Prefix] = None,
     ) -> None:
         self.request_id = request_id
         self.seqs_dict = {seq.seq_id: seq for seq in seqs}
@ -286,7 +293,6 @@ class SequenceGroup:
             first_token_time=None,
             time_in_queue=None)
         self.lora_request = lora_request
-        self.prefix: Optional[Prefix] = prefix
         self.prompt_logprobs: Optional[PromptLogprobs] = None
         self.state = SequenceGroupState()
 
@ -302,6 +308,10 @@ class SequenceGroup:
         # We use the prompt of an arbitrary sequence.
         return next(iter(self.seqs_dict.values())).data.prompt_token_ids
 
+    @property
+    def block_size(self) -> int:
+        return next(iter(self.seqs_dict.values())).block_size
+
     @property
     def lora_int_id(self) -> int:
         return self.lora_request.lora_int_id if self.lora_request else 0
@ -408,7 +418,6 @@ class SequenceGroupMetadata:
             numbers)
         state: Internal state tied to this sequence group.
         lora_request: LoRA request.
-        prefix: The prefix of the prompt of the sequence group.
     """
 
     def __init__(
@ -419,7 +428,7 @@ class SequenceGroupMetadata:
         sampling_params: SamplingParams,
         block_tables: Dict[int, List[int]],
         lora_request: Optional[LoRARequest] = None,
-        prefix: Optional[Prefix] = None,
+        computed_block_nums: Optional[List[int]] = None,
         state: Optional[SequenceGroupState] = None,
     ) -> None:
         self.request_id = request_id
@ -428,7 +437,7 @@ class SequenceGroupMetadata:
         self.sampling_params = sampling_params
         self.block_tables = block_tables
         self.lora_request = lora_request
-        self.prefix = prefix
+        self.computed_block_nums = computed_block_nums
         self.state = SequenceGroupState() if state is None else state
 
     @property
@ -145,33 +145,37 @@ class ModelRunner:
             prompt_tokens = seq_data.get_token_ids()
             prompt_len = len(prompt_tokens)
             prompt_lens.append(prompt_len)
-            prefix_len = 0
-            prefix = seq_group_metadata.prefix
-            if prefix is not None and prefix.computed:
-                prefix_len = prefix.get_length()
-                prompt_tokens = prompt_tokens[prefix_len:]
-                prefix_block_tables.append(prefix.get_block_numbers())
+            computed_len = 0
+
+            # NOTE: This only works for oooooooxxx style attention.
+            computed_block_nums = seq_group_metadata.computed_block_nums
+            if computed_block_nums is not None and len(
+                    computed_block_nums) > 0 and self.sliding_window is None:
+                # Prefix is not supported with sliding_window
+                computed_len = len(computed_block_nums) * self.block_size
+                prompt_tokens = prompt_tokens[computed_len:]
+                prefix_block_tables.append(computed_block_nums)
             else:
                 prefix_block_tables.append([])
             # actual prompt lens
-            context_lens.append(prefix_len)
-            subquery_lens.append(prompt_len - prefix_len)
+            context_lens.append(computed_len)
+            subquery_lens.append(prompt_len - computed_len)
 
             input_tokens.append(prompt_tokens)
             # NOTE(woosuk): Here we assume that the first token in the prompt
             # is always the first token in the sequence.
             input_positions.append(
-                list(range(prefix_len, prefix_len + len(prompt_tokens))))
+                list(range(computed_len, computed_len + len(prompt_tokens))))
 
             lora_id = seq_group_metadata.lora_int_id
 
             if lora_id > 0:
                 lora_requests.add(seq_group_metadata.lora_request)
 
-            lora_index_mapping.append([lora_id] * (prompt_len - prefix_len))
+            lora_index_mapping.append([lora_id] * (prompt_len - computed_len))
             lora_prompt_mapping.extend(
                 [lora_id] *
-                (prompt_len - prefix_len
+                (prompt_len - computed_len
                  if seq_group_metadata.sampling_params.prompt_logprobs else 1))
 
             if seq_group_metadata.block_tables is None:
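The arithmetic in this hunk is easy to check by hand: every entry of `computed_block_nums` stands for one fully cached block, so `computed_len = len(computed_block_nums) * block_size` prompt tokens can be skipped, and both the remaining input tokens and their positions start at that offset. A small sanity-check sketch with made-up numbers, not vLLM code:

# Sanity check of the computed_len bookkeeping (toy numbers only).
block_size = 16
prompt_tokens = list(range(40))        # a 40-token prompt
computed_block_nums = [3, 7]           # two physical blocks already cached

computed_len = len(computed_block_nums) * block_size     # 32 tokens cached
remaining_tokens = prompt_tokens[computed_len:]          # 8 tokens to prefill
positions = list(range(computed_len, computed_len + len(remaining_tokens)))

assert computed_len == 32
assert len(remaining_tokens) == 40 - 32
assert positions[0] == 32 and positions[-1] == 39
# context_lens receives computed_len (32); subquery_lens receives 40 - 32 = 8.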
@ -190,11 +194,11 @@ class ModelRunner:
             # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
             start_idx = 0
             if self.sliding_window is not None:
-                assert prefix_len == 0, (
+                assert computed_len == 0, (
                     "Prefix caching is currently not supported with "
                     "sliding window attention")
                 start_idx = max(0, prompt_len - self.sliding_window)
-            for i in range(prefix_len, prompt_len):
+            for i in range(computed_len, prompt_len):
                 if i < start_idx:
                     slot_mapping[-1].append(_PAD_SLOT_ID)
                     continue