[Misc] [Core] Implement RFC "Augment BaseExecutor interfaces to enable hardware-agnostic speculative decoding" (#3837)

Cade Daniel 2024-04-09 11:44:15 -07:00 committed by GitHub
parent 6d592eb430
commit e7c7067b45
20 changed files with 451 additions and 275 deletions

View File

@ -16,7 +16,7 @@ from vllm import SamplingParams
# Allow only 5 sequences of ~1024 tokens in worst case.
"block_size": 16,
"forced_num_gpu_blocks": 5 * (64 + 1),
"num_gpu_blocks_override": 5 * (64 + 1),
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
@ -162,14 +162,14 @@ def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator,
# Allow only 2 sequences of ~128 tokens in worst case.
# Note 8 = 128/block_size
"forced_num_gpu_blocks": 2 * (8 + 1),
"num_gpu_blocks_override": 2 * (8 + 1),
},
{
"block_size": 8,
# Allow only 2 sequences of ~128 tokens in worst case.
# Note 16 = 128/block_size
"forced_num_gpu_blocks": 2 * (16 + 1),
"num_gpu_blocks_override": 2 * (16 + 1),
}
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{

View File

@ -3,8 +3,8 @@ import random
import tempfile
from unittest.mock import patch
from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.lora.models import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.worker.worker import Worker
@ -27,6 +27,10 @@ def test_worker_apply_lora(sql_lora_files):
parallel_config=ParallelConfig(1, 1, False),
scheduler_config=SchedulerConfig(32, 32, 32),
device_config=DeviceConfig("cuda"),
cache_config=CacheConfig(block_size=16,
gpu_memory_utilization=1.,
swap_space=0,
cache_dtype="auto"),
local_rank=0,
rank=0,
lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32,

View File

@ -512,8 +512,8 @@ def test_init_device():
@torch.inference_mode()
def test_init_cache_engine():
"""Verify SpecDecodeWorker invokes init_cache_engine on proposer/scorer
def test_initialize_cache():
"""Verify SpecDecodeWorker invokes initialize_cache on proposer/scorer
workers.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
@ -525,12 +525,11 @@ def test_init_cache_engine():
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
cache_config = MagicMock()
kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023}
worker.initialize_cache(**kwargs)
worker.init_cache_engine(cache_config)
draft_worker.init_cache_engine.assert_called_once_with(cache_config)
target_worker.init_cache_engine.assert_called_once_with(cache_config)
draft_worker.initialize_cache.assert_called_once_with(**kwargs)
target_worker.initialize_cache.assert_called_once_with(**kwargs)
@pytest.mark.parametrize('available_gpu_blocks', [1, 1024])
@ -538,10 +537,10 @@ def test_init_cache_engine():
@pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096])
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
@pytest.mark.skip_global_cleanup
def test_profile_num_available_blocks(available_gpu_blocks: int,
available_cpu_blocks: int,
target_cache_block_size_bytes: int,
draft_kv_size_bytes: int):
def test_determine_num_available_blocks(available_gpu_blocks: int,
available_cpu_blocks: int,
target_cache_block_size_bytes: int,
draft_kv_size_bytes: int):
"""Verify SpecDecodeWorker correctly profiles num available GPU blocks.
Specifically, it should run profiling in the scorer worker, and then evenly
split the blocks between proposer and scorer worker.
@ -552,7 +551,7 @@ def test_profile_num_available_blocks(available_gpu_blocks: int,
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
target_worker.profile_num_available_blocks.return_value = (
target_worker.determine_num_available_blocks.return_value = (
available_gpu_blocks, available_cpu_blocks)
target_worker.get_cache_block_size_bytes.return_value = (
target_cache_block_size_bytes)
@ -561,17 +560,9 @@ def test_profile_num_available_blocks(available_gpu_blocks: int,
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
# These values do not directly impact the adjusted block size calculation,
# so they can be fixed.
gpu_memory_utilization = 0.9
cpu_swap_space = 100
block_size = 16
num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks()
num_gpu_blocks, num_cpu_blocks = worker.profile_num_available_blocks(
block_size, gpu_memory_utilization, cpu_swap_space, cache_dtype="auto")
target_worker.profile_num_available_blocks.assert_called_once_with(
block_size, gpu_memory_utilization, cpu_swap_space, "auto")
target_worker.determine_num_available_blocks.assert_called_once()
assert num_cpu_blocks == available_cpu_blocks
assert num_gpu_blocks == split_num_cache_blocks_evenly(

View File

@ -117,6 +117,7 @@ def create_worker(cls: type,
parallel_config=engine_config.parallel_config,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
cache_config=engine_config.cache_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
@ -128,8 +129,9 @@ def create_worker(cls: type,
engine_config.cache_config.num_gpu_blocks = num_gpu_blocks
engine_config.cache_config.num_cpu_blocks = 0
worker.init_cache_engine(engine_config.cache_config)
worker.warm_up_model()
worker.initialize_cache(
num_gpu_blocks=engine_config.cache_config.num_gpu_blocks,
num_cpu_blocks=engine_config.cache_config.num_cpu_blocks)
return worker
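The updated helper reflects the new worker lifecycle: `init_cache_engine` plus `warm_up_model` collapse into a single `initialize_cache` call that takes explicit block counts. A minimal sketch of the full sequence, assuming the block counts come from profiling rather than being overridden as in the test above:

worker.init_device()
worker.load_model()

# Profiling no longer takes block_size / gpu_memory_utilization arguments; the
# worker reads them from its own CacheConfig.
num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks()

# Cache allocation (and, on GPU, model warm-up / CUDA graph capture) now happens here.
worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks)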

View File

@ -11,8 +11,8 @@ def test_swap() -> None:
dtype="half",
load_format="dummy")
engine_config = engine_args.create_engine_config()
engine_config.cache_config.num_gpu_blocks = 100
engine_config.cache_config.num_cpu_blocks = 100
engine_config.cache_config.num_gpu_blocks = 1000
engine_config.cache_config.num_cpu_blocks = 1000
# Create the worker.
distributed_init_method = get_distributed_init_method(
@ -22,6 +22,7 @@ def test_swap() -> None:
parallel_config=engine_config.parallel_config,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
cache_config=engine_config.cache_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
@ -31,8 +32,9 @@ def test_swap() -> None:
# Initialize the worker.
worker.init_device()
worker.load_model()
worker.init_cache_engine(engine_config.cache_config)
worker.warm_up_model()
worker.initialize_cache(
num_gpu_blocks=engine_config.cache_config.num_gpu_blocks,
num_cpu_blocks=engine_config.cache_config.num_cpu_blocks)
# Randomly initialize the cache.
gpu_cache = worker.cache_engine.gpu_cache

View File

@ -334,7 +334,7 @@ class CacheConfig:
vLLM execution.
swap_space: Size of the CPU swap space per GPU (in GiB).
cache_dtype: Data type for kv cache storage.
forced_num_gpu_blocks: Number of GPU blocks to use. This overrides the
num_gpu_blocks_override: Number of GPU blocks to use. This overrides the
profiled num_gpu_blocks if specified. Does nothing if None.
"""
@ -344,14 +344,14 @@ class CacheConfig:
gpu_memory_utilization: float,
swap_space: int,
cache_dtype: str,
forced_num_gpu_blocks: Optional[int] = None,
num_gpu_blocks_override: Optional[int] = None,
sliding_window: Optional[int] = None,
enable_prefix_caching: bool = False,
) -> None:
self.block_size = block_size
self.gpu_memory_utilization = gpu_memory_utilization
self.swap_space_bytes = swap_space * _GB
self.forced_num_gpu_blocks = forced_num_gpu_blocks
self.num_gpu_blocks_override = num_gpu_blocks_override
self.cache_dtype = cache_dtype
self.sliding_window = sliding_window
self.enable_prefix_caching = enable_prefix_caching
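A hedged construction example for the renamed parameter (argument names follow the signature above; the values are illustrative only):

from vllm.config import CacheConfig

cache_config = CacheConfig(
    block_size=16,
    gpu_memory_utilization=0.9,
    swap_space=4,                   # GiB of CPU swap space per GPU
    cache_dtype="auto",
    num_gpu_blocks_override=2048,   # skip the profiled value and force 2048 GPU blocks
)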

View File

@ -55,7 +55,7 @@ class EngineArgs:
max_cpu_loras: Optional[int] = None
device: str = 'auto'
ray_workers_use_nsight: bool = False
forced_num_gpu_blocks: Optional[int] = None
num_gpu_blocks_override: Optional[int] = None
num_lookahead_slots: int = 0
# Related to Vision-language models such as llava
@ -246,7 +246,7 @@ class EngineArgs:
'the model executor, which can range from 0 to 1.'
'If unspecified, will use the default value of 0.9.')
parser.add_argument(
'--forced-num-gpu-blocks',
'--num-gpu-blocks-override',
type=int,
default=None,
help='If specified, ignore GPU profiling result and use this number'
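The same override is also reachable from the Python API, since LLM forwards keyword arguments to EngineArgs; a hedged example (model name and value are placeholders):

from vllm import LLM

# Equivalent to passing --num-gpu-blocks-override 2048 on the command line.
llm = LLM(model="facebook/opt-125m", num_gpu_blocks_override=2048)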
@ -426,7 +426,7 @@ class EngineArgs:
cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization,
self.swap_space, self.kv_cache_dtype,
self.forced_num_gpu_blocks,
self.num_gpu_blocks_override,
model_config.get_sliding_window(),
self.enable_prefix_caching)
parallel_config = ParallelConfig(

View File

@ -127,6 +127,8 @@ class LLMEngine:
speculative_config=speculative_config,
)
self._initialize_kv_caches()
# If usage stat is enabled, collect relevant info.
if is_usage_stats_enabled():
from vllm.model_executor.model_loader import (
@ -178,6 +180,26 @@ class LLMEngine:
labels=dict(model_name=model_config.model))
self.stat_logger.info("cache_config", self.cache_config)
def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s).
The workers will determine the number of blocks in both the GPU cache
and the swap CPU cache.
"""
num_gpu_blocks, num_cpu_blocks = (
self.model_executor.determine_num_available_blocks())
if self.cache_config.num_gpu_blocks_override is not None:
num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
logger.info(f"Overriding {num_gpu_blocks=} with "
f"{num_gpu_blocks_override=}")
num_gpu_blocks = num_gpu_blocks_override
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
@classmethod
def from_engine_args(
cls,

View File

@ -35,7 +35,6 @@ class CPUExecutor(ExecutorBase):
# Instantiate the worker and load the model to CPU.
self._init_worker()
self._init_cache()
def _init_worker(self):
from vllm.worker.cpu_worker import CPUWorker
@ -46,10 +45,11 @@ class CPUExecutor(ExecutorBase):
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
self.driver_worker = CPUWorker(
self.model_config,
self.parallel_config,
self.scheduler_config,
self.device_config,
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
@ -60,35 +60,21 @@ class CPUExecutor(ExecutorBase):
self.driver_worker.init_device()
self.driver_worker.load_model()
def _init_cache(self) -> None:
num_cpu_blocks = self.driver_worker.get_cpu_cache_block_num(
block_size=self.cache_config.block_size,
cache_space=self.cache_config.cpu_kvcache_space_bytes,
cache_dtype=self.cache_config.cache_dtype,
)
def determine_num_available_blocks(self) -> tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
"""
return self.driver_worker.determine_num_available_blocks()
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache by invoking the underlying worker.
"""
# NOTE: We log here to avoid multiple logs when number of workers is
# greater than one. We could log in the engine, but not all executors
# have GPUs.
logger.info(f"# CPU blocks: {num_cpu_blocks}")
if num_cpu_blocks <= 0:
raise ValueError("No available memory for the cache blocks. "
"Try increasing `VLLM_CPU_KVCACHE_SPACE` when "
"initializing the engine.")
max_seq_len = self.cache_config.block_size * num_cpu_blocks
if self.model_config.max_model_len > max_seq_len:
raise ValueError(
f"The model's max seq len ({self.model_config.max_model_len}) "
"is larger than the maximum number of tokens that can be "
f"stored in KV cache ({max_seq_len}). Try increasing "
"`VLLM_CPU_KVCACHE_SPACE` or decreasing `max_model_len` when "
"initializing the engine.")
# Note: To reuse the cache management procedure,
# use cpu cache as 'gpu cache'.
self.cache_config.num_gpu_blocks = num_cpu_blocks # type: ignore
self.cache_config.num_cpu_blocks = 0 # type: ignore
# Initialize the cache.
self.driver_worker.init_cache_engine(cache_config=self.cache_config)
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
def execute_model(self,
seq_group_metadata_list: List[SequenceGroupMetadata],
@ -104,13 +90,13 @@ class CPUExecutor(ExecutorBase):
return output
def add_lora(self, lora_request: LoRARequest) -> bool:
raise NotImplementedError("LoRA is not implemented for cpu backend.")
return self.driver_worker.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
raise NotImplementedError("LoRA is not implemented for cpu backend.")
return self.driver_worker.remove_lora(lora_id)
def list_loras(self) -> List[int]:
raise NotImplementedError("LoRA is not implemented for cpu backend.")
return self.driver_worker.list_loras()
def check_health(self) -> None:
# CPUExecutor will always be healthy as long as

View File

@ -30,6 +30,29 @@ class ExecutorBase(ABC):
) -> None:
raise NotImplementedError
@abstractmethod
def determine_num_available_blocks(self) -> tuple[int, int]:
"""Determine the number of available blocks for the GPU KV cache and
swappable CPU KV cache.
Normally, this should simply delegate to the underlying Worker. Some
ExecutorBase may require modification of the result, e.g. to ensure the
selected cache sizes are compatible with all workers.
Returns a tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
are blocks that are "active" on the device and can be appended to.
num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
appended to.
"""
raise NotImplementedError
@abstractmethod
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache with the given size in blocks.
"""
raise NotImplementedError
@abstractmethod
def execute_model(self,
seq_group_metadata_list: List[SequenceGroupMetadata],
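Together, these two hooks are what makes the executor layer hardware-agnostic: the engine decides when to size and create the KV cache, and each backend decides how. A minimal, hypothetical sketch of a single-worker backend implementing only the new methods (the class name and driver_worker attribute are assumptions, not part of this commit; the remaining abstract methods are omitted):

from typing import Tuple

from vllm.executor.executor_base import ExecutorBase


class SingleDeviceExecutor(ExecutorBase):
    """Hypothetical backend sketch; only the new cache hooks are shown."""

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        # Delegate profiling to the one worker this executor manages.
        return self.driver_worker.determine_num_available_blocks()

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        # Allocate the KV cache with the sizes chosen by the engine.
        self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)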

View File

@ -4,7 +4,6 @@ from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
VisionLanguageConfig)
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.executor.utils import check_block_size_valid
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
@ -41,9 +40,6 @@ class GPUExecutor(ExecutorBase):
# Instantiate the worker and load the model to GPU.
self._init_worker()
# Profile the memory usage and initialize the cache.
self._init_cache()
def _init_worker(self):
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
@ -55,61 +51,37 @@ class GPUExecutor(ExecutorBase):
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
self.driver_worker = Worker(
self.model_config,
self.parallel_config,
self.scheduler_config,
self.device_config,
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=True,
)
self.driver_worker.init_device()
self.driver_worker.load_model()
def _init_cache(self) -> None:
"""Profiles the memory usage and initializes the KV cache.
The engine first profiles the existing memory usage.
Then, it allocates the remaining memory for KV blocks.
.. tip::
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
def determine_num_available_blocks(self) -> tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
"""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
num_gpu_blocks, num_cpu_blocks = (
self.driver_worker.profile_num_available_blocks(
block_size=self.cache_config.block_size,
gpu_memory_utilization=self.cache_config.
gpu_memory_utilization,
cpu_swap_space=self.cache_config.swap_space_bytes,
cache_dtype=self.cache_config.cache_dtype,
))
if self.cache_config.forced_num_gpu_blocks is not None:
forced_num_gpu_blocks = self.cache_config.forced_num_gpu_blocks
logger.info(f"Replacing profiled {num_gpu_blocks=} with "
f"{forced_num_gpu_blocks=}")
num_gpu_blocks = forced_num_gpu_blocks
return self.driver_worker.determine_num_available_blocks()
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None:
"""Initialize the KV cache by invoking the underlying worker.
"""
# NOTE: This is logged in the executor because there can be >1 worker
# with other executors. We could log in the engine level, but work
# remains to abstract away the device for non-GPU configurations.
logger.info(f"# GPU blocks: {num_gpu_blocks}, "
f"# CPU blocks: {num_cpu_blocks}")
check_block_size_valid(num_gpu_blocks, self.cache_config.block_size,
self.model_config.max_model_len)
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
# Initialize the cache.
self.driver_worker.init_cache_engine(cache_config=self.cache_config)
# Warm up the model. This includes capturing the model into CUDA graph
# if enforce_eager is False.
self.driver_worker.warm_up_model()
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
def execute_model(self,
seq_group_metadata_list: List[SequenceGroupMetadata],

View File

@ -25,7 +25,6 @@ class NeuronExecutor(ExecutorBase):
speculative_config: Optional[SpeculativeConfig],
) -> None:
self.model_config = model_config
self.cache_config = cache_config
assert lora_config is None, "LoRA is not supported for Neuron backend."
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
@ -33,12 +32,6 @@ class NeuronExecutor(ExecutorBase):
assert (not speculative_config
), "Speculative decoding not yet supported for Neuron backend."
# Set the number of GPU blocks to be the same as the maximum number of
# sequences that can be processed in a single batch. This is equivalent
# to schedule without PagedAttention.
self.cache_config.num_gpu_blocks = self.scheduler_config.max_num_seqs
self.cache_config.num_cpu_blocks = 0
# Instantiate the worker and load the model to the device.
self._init_worker()
@ -54,6 +47,18 @@ class NeuronExecutor(ExecutorBase):
self.driver_worker.init_device()
self.driver_worker.load_model()
def determine_num_available_blocks(self) -> tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
"""
return self.driver_worker.determine_num_available_blocks()
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache by invoking the underlying worker.
"""
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
def execute_model(self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
@ -68,16 +73,13 @@ class NeuronExecutor(ExecutorBase):
return output
def add_lora(self, lora_request: LoRARequest) -> bool:
raise NotImplementedError(
"LoRA is not implemented for neuron backend.")
return self.driver_worker.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
raise NotImplementedError(
"LoRA is not implemented for neuron backend.")
return self.driver_worker.remove_lora(lora_id)
def list_loras(self) -> List[int]:
raise NotImplementedError(
"LoRA is not implemented for neuron backend.")
return self.driver_worker.list_loras()
def check_health(self) -> None:
# NeuronExecutor will always be healthy as long as

View File

@ -10,7 +10,6 @@ from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
VisionLanguageConfig)
from vllm.engine.ray_utils import RayWorkerVllm, ray
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.executor.utils import check_block_size_valid
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
@ -65,9 +64,6 @@ class RayGPUExecutor(ExecutorBase):
# Create the parallel GPU workers.
self._init_workers_ray(placement_group)
# Profile the memory usage and initialize the cache.
self._init_cache()
self.forward_dag = None
if USE_RAY_COMPILED_DAG:
self.forward_dag = self._compiled_ray_dag()
@ -154,8 +150,8 @@ class RayGPUExecutor(ExecutorBase):
scheduler_config = copy.deepcopy(self.scheduler_config)
device_config = copy.deepcopy(self.device_config)
lora_config = copy.deepcopy(self.lora_config)
cache_config = copy.deepcopy(self.cache_config)
vision_language_config = copy.deepcopy(self.vision_language_config)
kv_cache_dtype = self.cache_config.cache_dtype
# Initialize the actual workers with the Worker class.
for rank, (worker, (node_id, _)) in enumerate(
@ -165,32 +161,32 @@ class RayGPUExecutor(ExecutorBase):
local_rank = node_workers[node_id].index(rank)
worker.init_worker.remote(
lambda rank=rank, local_rank=local_rank: Worker(
model_config,
parallel_config,
scheduler_config,
device_config,
local_rank,
rank,
distributed_init_method,
model_config=model_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
device_config=device_config,
cache_config=cache_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
lora_config=lora_config,
vision_language_config=vision_language_config,
kv_cache_dtype=kv_cache_dtype,
))
# Initialize the driver worker with the Worker class.
driver_rank = 0
driver_local_rank = node_workers[driver_node_id].index(driver_rank)
self.driver_worker = Worker(
self.model_config,
self.parallel_config,
self.scheduler_config,
self.device_config,
driver_local_rank,
driver_rank,
distributed_init_method,
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
local_rank=driver_local_rank,
rank=driver_rank,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=True,
)
@ -201,35 +197,18 @@ class RayGPUExecutor(ExecutorBase):
max_parallel_loading_workers,
)
def _init_cache(self) -> None:
"""Profiles the memory usage and initializes the KV cache.
def determine_num_available_blocks(self) -> tuple[int, int]:
"""Determine the number of available KV blocks.
The engine will first conduct a profiling of the existing memory usage.
Then, it calculates the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
More details can be found in the
:meth:`~vllm.worker.worker.Worker.profile_num_available_blocks` method
from class :class:`~vllm.worker.Worker`.
This invokes `determine_num_available_blocks` on each worker and takes
the min of the results, guaranteeing that the selected cache sizes are
compatible with all workers.
Afterwards, as there may be multiple workers,
we take the minimum number of blocks across all workers
to ensure this can be applied to all of them.
Finally, the engine will initialize the KV cache
with the calculated number of blocks.
.. tip::
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
Returns:
- tuple[num_gpu_blocks, num_cpu_blocks]
"""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
num_blocks = self._run_workers(
"profile_num_available_blocks",
block_size=self.cache_config.block_size,
gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
cpu_swap_space=self.cache_config.swap_space_bytes,
cache_dtype=self.cache_config.cache_dtype,
)
num_blocks = self._run_workers("determine_num_available_blocks", )
# Since we use a shared centralized controller, we take the minimum
# number of blocks across all workers to make sure all the memory
@ -237,26 +216,25 @@ class RayGPUExecutor(ExecutorBase):
num_gpu_blocks = min(b[0] for b in num_blocks)
num_cpu_blocks = min(b[1] for b in num_blocks)
if self.cache_config.forced_num_gpu_blocks is not None:
forced_num_gpu_blocks = self.cache_config.forced_num_gpu_blocks
logger.info(f"Replacing profiled {num_gpu_blocks=} with "
f"{forced_num_gpu_blocks=}")
num_gpu_blocks = forced_num_gpu_blocks
return num_gpu_blocks, num_cpu_blocks
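A concrete illustration of the reduction above, with made-up per-worker results:

# (num_gpu_blocks, num_cpu_blocks) reported by each Ray worker.
num_blocks = [(900, 2048), (850, 2048)]

num_gpu_blocks = min(b[0] for b in num_blocks)   # 850 -> fits on every GPU
num_cpu_blocks = min(b[1] for b in num_blocks)   # 2048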
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache in all workers.
"""
# NOTE: We log here to avoid multiple logs when number of workers is
# greater than one. We could log in the engine, but not all executors
# have GPUs.
logger.info(f"# GPU blocks: {num_gpu_blocks}, "
f"# CPU blocks: {num_cpu_blocks}")
check_block_size_valid(num_gpu_blocks, self.cache_config.block_size,
self.model_config.max_model_len)
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
# Initialize the cache.
self._run_workers("init_cache_engine", cache_config=self.cache_config)
# Warm up the model. This includes capturing the model into CUDA graph
# if enforce_eager is False.
self._run_workers("warm_up_model")
self._run_workers("initialize_cache",
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks)
def execute_model(self,
seq_group_metadata_list: List[SequenceGroupMetadata],

View File

@ -1,13 +0,0 @@
def check_block_size_valid(num_gpu_blocks, block_size, max_model_len) -> None:
if num_gpu_blocks <= 0:
raise ValueError("No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when "
"initializing the engine.")
max_seq_len = block_size * num_gpu_blocks
if max_model_len > max_seq_len:
raise ValueError(
f"The model's max seq len ({max_model_len}) "
"is larger than the maximum number of tokens that can be "
f"stored in KV cache ({max_seq_len}). Try increasing "
"`gpu_memory_utilization` or decreasing `max_model_len` when "
"initializing the engine.")

View File

@ -3,7 +3,6 @@ from typing import Dict, List, Optional, Tuple
import torch
from vllm.config import CacheConfig
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.sequence import (SamplerOutput, SequenceGroupMetadata,
SequenceGroupOutput, SequenceOutput)
@ -15,9 +14,10 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
split_batch_by_proposal_len)
from vllm.worker.worker import Worker
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
class SpecDecodeWorker:
class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"""Worker which implements speculative decoding.
Speculative decoding reduces decoding per-token latency by using a proposal
@ -94,10 +94,7 @@ class SpecDecodeWorker:
device=self.device,
vocab_size=self._vocab_size)
def profile_num_available_blocks(self, block_size: int,
gpu_memory_utilization: float,
cpu_swap_space: int,
cache_dtype: str) -> Tuple[int, int]:
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of cache blocks to use.
This is done by profiling the scorer model (which is typically the
@ -106,27 +103,26 @@ class SpecDecodeWorker:
such that the number of blocks is equal in both KV caches.
"""
num_gpu_blocks, num_cpu_blocks = (
self.scorer_worker.profile_num_available_blocks(
block_size, gpu_memory_utilization, cpu_swap_space,
cache_dtype))
self.scorer_worker.determine_num_available_blocks())
scorer_cache_block_size_bytes = (
self.scorer_worker.get_cache_block_size_bytes(
block_size, cache_dtype))
self.scorer_worker.get_cache_block_size_bytes())
proposer_cache_block_size_bytes = (
self.proposer_worker.get_cache_block_size_bytes(
block_size, cache_dtype))
self.proposer_worker.get_cache_block_size_bytes())
new_num_gpu_blocks = split_num_cache_blocks_evenly(
scorer_cache_block_size_bytes, proposer_cache_block_size_bytes,
num_gpu_blocks)
return new_num_gpu_blocks, num_cpu_blocks
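The split keeps the proposer and scorer caches at the same block count by dividing the profiled KV memory in proportion to the two per-block sizes. A hedged numeric sketch of that arithmetic (byte sizes are illustrative, and the helper's internals are paraphrased rather than quoted):

scorer_block_bytes = 16 * 1024     # per-block KV size of the target model
proposer_block_bytes = 4 * 1024    # per-block KV size of the draft model
profiled_gpu_blocks = 1000         # blocks the scorer alone could allocate

# Re-divide the same memory so both caches hold an equal number of blocks.
total_bytes = profiled_gpu_blocks * scorer_block_bytes
shared_blocks = total_bytes // (scorer_block_bytes + proposer_block_bytes)   # 800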
def init_cache_engine(self, cache_config: CacheConfig):
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the cache engine of the scorer and proposer workers.
"""
self.scorer_worker.init_cache_engine(cache_config)
self.proposer_worker.init_cache_engine(cache_config)
self.scorer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks)
self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks)
@torch.inference_mode()
def execute_model(
@ -351,6 +347,16 @@ class SpecDecodeWorker:
def device(self):
return self.scorer_worker.device
def get_cache_block_size_bytes(self):
"""Return the size of a cache block in bytes.
This function is only used to compose workers within a SpecDecodeWorker.
We leave composing a SpecDecodeWorker within a SpecDecodeWorker
undefined for now, although it could be implemented in the future.
See https://arxiv.org/abs/2308.04623.
"""
raise NotImplementedError
def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
proposer_cache_block_size_bytes: int,

View File

@ -82,8 +82,7 @@ class CacheEngine:
@staticmethod
def get_cache_block_size(
block_size: int,
cache_dtype: str,
cache_config: CacheConfig,
model_config: ModelConfig,
parallel_config: ParallelConfig,
) -> int:
@ -91,13 +90,13 @@ class CacheEngine:
num_heads = model_config.get_num_kv_heads(parallel_config)
num_layers = model_config.get_num_layers(parallel_config)
key_cache_block = block_size * num_heads * head_size
key_cache_block = cache_config.block_size * num_heads * head_size
value_cache_block = key_cache_block
total = num_layers * (key_cache_block + value_cache_block)
if cache_dtype == "auto":
if cache_config.cache_dtype == "auto":
dtype = model_config.dtype
else:
dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
dtype_size = _get_dtype_size(dtype)
return dtype_size * total
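For intuition, the value returned above is dtype_size * num_layers * 2 * block_size * num_kv_heads * head_size. A worked example with assumed model dimensions:

dtype_size = 2          # fp16 / bf16
num_layers = 32
block_size = 16         # tokens per KV block
num_kv_heads = 32
head_size = 128

key_cache_block = block_size * num_kv_heads * head_size    # 65,536 elements
total = num_layers * (key_cache_block + key_cache_block)   # key + value caches
block_bytes = dtype_size * total                           # 8,388,608 bytes (~8 MiB)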

View File

@ -17,6 +17,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.worker.model_runner import ModelRunner
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
logger = init_logger(__name__)
@ -112,7 +113,7 @@ class CPUCacheEngine:
return dtype_size * total
class CPUWorker:
class CPUWorker(LoraNotSupportedWorkerBase):
"""A worker class that executes (a partition of) the model on a CPU socket.
Each worker is associated with a single CPU socket. The worker is
@ -127,6 +128,7 @@ class CPUWorker:
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
@ -138,6 +140,7 @@ class CPUWorker:
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
@ -154,8 +157,7 @@ class CPUWorker:
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker)
# Uninitialized cache engine. Will be initialized by
# self.init_cache_engine().
self.cache_config = None
# initialize_cache.
self.cache_engine = None
self.cpu_cache = None
@ -167,28 +169,70 @@ class CPUWorker:
def load_model(self):
self.model_runner.load_model()
def get_cpu_cache_block_num(
self,
block_size: int,
cache_space: int,
cache_dtype: str,
) -> int:
"""
Args:
block_size: The size of the cache block.
cache_space: The size of the CPU KV cache space in bytes.
def determine_num_available_blocks(self) -> tuple[int, int]:
"""Determine the number of blocks available for the KV cache.
This determines how many KV blocks can fit into the configured CPU
KV cache space.
Note that since vLLM assumes a block resides on GPU if it can be
modified, we return num_gpu_blocks=num_cpu_blocks and num_cpu_blocks=0.
This allows us to reuse the scheduler of vLLM without generalizing it
to different devices.
"""
# For CPU device, the block number will be calculated based on the
# cpu_kvcache_space.
cache_block_size = CPUCacheEngine.get_cache_block_size(
block_size, cache_dtype, self.model_config, self.parallel_config)
num_cpu_blocks = int(cache_space // cache_block_size)
cache_block_size = self.get_cache_block_size_bytes()
num_cpu_blocks = int(self.cache_config.cpu_kvcache_space_bytes //
cache_block_size)
num_cpu_blocks = max(num_cpu_blocks, 0)
return num_cpu_blocks
# Note: To reuse the cache management procedure,
# use cpu cache as 'gpu cache'.
num_gpu_blocks = num_cpu_blocks
num_cpu_blocks = 0
return num_gpu_blocks, num_cpu_blocks
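A rough numeric sketch of the mapping above (the cache space and block size are illustrative; the space normally comes from the VLLM_CPU_KVCACHE_SPACE setting mentioned below):

cpu_kvcache_space_bytes = 4 * 1024**3   # e.g. 4 GiB of CPU KV cache space
cache_block_size = 8 * 1024**2          # bytes per block, from get_cache_block_size_bytes()

num_blocks = cpu_kvcache_space_bytes // cache_block_size   # 512
# Reported as (num_gpu_blocks=512, num_cpu_blocks=0): the CPU cache is treated as
# the "device" cache so the unmodified scheduler never tries to swap.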
def init_cache_engine(self, cache_config: CacheConfig) -> None:
self.cache_config = cache_config
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache. Currently, swappable CPU memory is not
supported.
Since this worker does not support GPUs, we use the num_gpu_blocks to
determine how many non-swappable CPU blocks to allocate.
"""
assert (num_cpu_blocks == 0
), f"{type(self)} does not support swappable cache"
# Note: To reuse the cache management procedure,
# use cpu cache as 'gpu cache'.
num_cpu_blocks = num_gpu_blocks
self._validate_num_cpu_blocks(num_cpu_blocks)
self.cache_config.num_gpu_blocks = num_cpu_blocks
self.cache_config.num_cpu_blocks = 0
# Initialize the cache.
self._init_cache_engine()
def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None:
"""Raise errors if the num_cpu_blocks is invalid.
"""
if num_cpu_blocks <= 0:
raise ValueError("No available memory for the cache blocks. "
"Try increasing `VLLM_CPU_KVCACHE_SPACE` when "
"initializing the engine.")
max_seq_len = self.cache_config.block_size * num_cpu_blocks
if self.model_config.max_model_len > max_seq_len:
raise ValueError(
f"The model's max seq len ({self.model_config.max_model_len}) "
"is larger than the maximum number of tokens that can be "
f"stored in KV cache ({max_seq_len}). Try increasing "
"`VLLM_CPU_KVCACHE_SPACE` or decreasing `max_model_len` when "
"initializing the engine.")
def _init_cache_engine(self) -> None:
self.cache_engine = CPUCacheEngine(self.cache_config,
self.model_config,
self.parallel_config,
@ -264,3 +308,10 @@ class CPUWorker:
ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
def get_cache_block_size_bytes(self) -> int:
"""Return the size in bytes of a single KV cache block.
"""
return CPUCacheEngine.get_cache_block_size(
self.cache_config.block_size, self.cache_config.cache_dtype,
self.model_config, self.parallel_config)

View File

@ -4,14 +4,15 @@ from typing import List, Optional
import torch
import torch.distributed
from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.model_executor import set_random_seed
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.neuron_model_runner import NeuronModelRunner
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
class NeuronWorker:
class NeuronWorker(LoraNotSupportedWorkerBase):
"""A worker class that executes the model on a group of neuron cores.
"""
@ -21,11 +22,13 @@ class NeuronWorker:
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
) -> None:
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.model_runner = NeuronModelRunner(model_config, parallel_config,
scheduler_config, device_config)
@ -37,6 +40,35 @@ class NeuronWorker:
def load_model(self):
self.model_runner.load_model()
def determine_num_available_blocks(self) -> tuple[int, int]:
"""Determine the number of available KV blocks.
Swapping is not yet supported, so always return num_cpu_blocks=0.
We configure num_gpu_blocks to be equal to max_num_seqs.
"""
# Set the number of GPU blocks to be the same as the maximum number of
# sequences that can be processed in a single batch. This is equivalent
# to schedule without PagedAttention.
num_gpu_blocks = self.scheduler_config.max_num_seqs
# Swap not yet supported with Neuron backend.
num_cpu_blocks = 0
return num_gpu_blocks, num_cpu_blocks
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache.
"""
# Different values are not tested.
assert num_cpu_blocks == 0
assert num_gpu_blocks == self.scheduler_config.max_num_seqs
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
@torch.inference_mode()
def execute_model(
self,
@ -50,3 +82,10 @@ class NeuronWorker:
output = self.model_runner.execute_model(seq_group_metadata_list)
return output
def get_cache_block_size_bytes(self) -> int:
"""Determine the size in bytes of a cache block.
This is required for speculative decoding; it is not yet implemented.
"""
raise NotImplementedError

View File

@ -19,9 +19,10 @@ from vllm.model_executor.parallel_utils.parallel_state import (
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner
from vllm.worker.worker_base import WorkerBase
class Worker:
class Worker(WorkerBase):
"""A worker class that executes (a partition of) the model on a GPU.
Each worker is associated with a single GPU. The worker is responsible for
@ -35,18 +36,19 @@ class Worker:
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
vision_language_config: Optional[VisionLanguageConfig] = None,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
) -> None:
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
@ -66,12 +68,11 @@ class Worker:
scheduler_config,
device_config,
lora_config=self.lora_config,
kv_cache_dtype=kv_cache_dtype,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=is_driver_worker,
vision_language_config=vision_language_config)
# Uninitialized cache engine. Will be initialized by
# self.init_cache_engine().
self.cache_config = None
# initialize_cache.
self.cache_engine = None
self.gpu_cache = None
@ -107,20 +108,17 @@ class Worker:
self.model_runner.load_model()
@torch.inference_mode()
def profile_num_available_blocks(
self,
block_size: int,
gpu_memory_utilization: float,
cpu_swap_space: int,
cache_dtype: str,
) -> Tuple[int, int]:
"""Profiles the peak memory usage of the model and returns the maximum
number of GPU and CPU cache blocks that can be allocated.
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Profiles the peak memory usage of the model to determine how many
KV blocks may be allocated without OOMs.
Args:
block_size: The size of the cache block.
gpu_memory_utilization: The fraction of the total GPU memory to use.
cpu_swap_space: The size of the CPU swap space in bytes.
The engine will first conduct a profiling of the existing memory usage.
Then, it calculates the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
.. tip::
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
@ -141,12 +139,12 @@ class Worker:
"Error in memory profiling. This happens when the GPU memory was "
"not properly cleaned up before initializing the vLLM instance.")
cache_block_size = self.get_cache_block_size_bytes(
block_size, cache_dtype)
cache_block_size = self.get_cache_block_size_bytes()
num_gpu_blocks = int(
(total_gpu_memory * gpu_memory_utilization - peak_memory) //
cache_block_size)
num_cpu_blocks = int(cpu_swap_space // cache_block_size)
(total_gpu_memory * self.cache_config.gpu_memory_utilization -
peak_memory) // cache_block_size)
num_cpu_blocks = int(self.cache_config.swap_space_bytes //
cache_block_size)
num_gpu_blocks = max(num_gpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0)
if self.model_runner.lora_manager:
@ -155,14 +153,30 @@ class Worker:
torch.cuda.empty_cache()
return num_gpu_blocks, num_cpu_blocks
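A worked example of the arithmetic above, with assumed numbers:

total_gpu_memory = 80 * 1024**3       # 80 GiB device
gpu_memory_utilization = 0.9          # from CacheConfig
peak_memory = 16 * 1024**3            # measured during the profiling forward pass
cache_block_size = 8 * 1024**2        # bytes per KV block
swap_space_bytes = 4 * 1024**3        # CPU swap space from CacheConfig

num_gpu_blocks = int((total_gpu_memory * gpu_memory_utilization - peak_memory)
                     // cache_block_size)                     # 7168
num_cpu_blocks = int(swap_space_bytes // cache_block_size)    # 512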
def init_cache_engine(self, cache_config: CacheConfig) -> None:
self.cache_config = cache_config
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Allocate GPU and CPU KV cache with the specified number of blocks.
This also warms up the model, which may record CUDA graphs.
"""
raise_if_cache_size_invalid(num_gpu_blocks,
self.cache_config.block_size,
self.model_config.max_model_len)
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
self._init_cache_engine()
self._warm_up_model()
def _init_cache_engine(self):
assert self.cache_config.num_gpu_blocks is not None
self.cache_engine = CacheEngine(self.cache_config, self.model_config,
self.parallel_config)
self.gpu_cache = self.cache_engine.gpu_cache
self.model_runner.set_block_size(self.cache_engine.block_size)
def warm_up_model(self) -> None:
def _warm_up_model(self) -> None:
if not self.model_config.enforce_eager:
self.model_runner.capture_model(self.gpu_cache)
# Reset the seed to ensure that the random state is not affected by
@ -239,11 +253,10 @@ class Worker:
def vocab_size(self) -> int:
return self.model_runner.vocab_size
def get_cache_block_size_bytes(self, block_size: int,
cache_dtype: str) -> int:
def get_cache_block_size_bytes(self) -> int:
"""Get the size of the KV cache block size in bytes.
"""
return CacheEngine.get_cache_block_size(block_size, cache_dtype,
return CacheEngine.get_cache_block_size(self.cache_config,
self.model_config,
self.parallel_config)
@ -300,3 +313,19 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
f"{compute_capability[0]}.{compute_capability[1]}. "
"You can use float16 instead by explicitly setting the"
"`dtype` flag in CLI, for example: --dtype=half.")
def raise_if_cache_size_invalid(num_gpu_blocks, block_size,
max_model_len) -> None:
if num_gpu_blocks <= 0:
raise ValueError("No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when "
"initializing the engine.")
max_seq_len = block_size * num_gpu_blocks
if max_model_len > max_seq_len:
raise ValueError(
f"The model's max seq len ({max_model_len}) "
"is larger than the maximum number of tokens that can be "
f"stored in KV cache ({max_seq_len}). Try increasing "
"`gpu_memory_utilization` or decreasing `max_model_len` when "
"initializing the engine.")

View File

@ -0,0 +1,83 @@
from abc import ABC, abstractmethod
from typing import Dict, List
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
class WorkerBase(ABC):
"""Worker interface that allows vLLM to cleanly separate implementations for
different hardware.
"""
@abstractmethod
def init_device(self) -> None:
"""Initialize device state, such as loading the model or other on-device
memory allocations.
"""
raise NotImplementedError
@abstractmethod
def determine_num_available_blocks(self) -> tuple[int, int]:
"""Determine the number of available blocks for the GPU KV cache and
swappable CPU KV cache.
The implementation may run profiling or other heuristics to determine
the size of caches.
Returns a tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
are blocks that are "active" on the device and can be appended to.
num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
appended to.
"""
raise NotImplementedError
@abstractmethod
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache with the given size in blocks.
"""
raise NotImplementedError
@abstractmethod
def execute_model(self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
"""Executes one model step on the given sequences."""
raise NotImplementedError
@abstractmethod
def get_cache_block_size_bytes() -> int:
"""Return the size of a single cache block, in bytes. Used in
speculative decoding.
"""
raise NotImplementedError
@abstractmethod
def add_lora(self, lora_request: LoRARequest) -> bool:
raise NotImplementedError
@abstractmethod
def remove_lora(self, lora_id: int) -> bool:
raise NotImplementedError
@abstractmethod
def list_loras(self) -> List[int]:
raise NotImplementedError
class LoraNotSupportedWorkerBase(WorkerBase):
"""Partial implementation of WorkerBase that raises exceptions when LoRA
methods are invoked.
"""
def add_lora(self, lora_request: LoRARequest) -> bool:
raise ValueError(f"{type(self)} does not support LoRA")
def remove_lora(self, lora_id: int) -> bool:
raise ValueError(f"{type(self)} does not support LoRA")
def list_loras(self) -> List[int]:
raise ValueError(f"{type(self)} does not support LoRA")