mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 18:25:40 +08:00
[Misc] [Core] Implement RFC "Augment BaseExecutor interfaces to enable hardware-agnostic speculative decoding" (#3837)
This commit is contained in:
parent
6d592eb430
commit
e7c7067b45
@ -16,7 +16,7 @@ from vllm import SamplingParams
|
|||||||
|
|
||||||
# Allow only 5 sequences of ~1024 tokens in worst case.
|
# Allow only 5 sequences of ~1024 tokens in worst case.
|
||||||
"block_size": 16,
|
"block_size": 16,
|
||||||
"forced_num_gpu_blocks": 5 * (64 + 1),
|
"num_gpu_blocks_override": 5 * (64 + 1),
|
||||||
}])
|
}])
|
||||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{
|
||||||
@ -162,14 +162,14 @@ def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator,
|
|||||||
|
|
||||||
# Allow only 2 sequences of ~128 tokens in worst case.
|
# Allow only 2 sequences of ~128 tokens in worst case.
|
||||||
# Note 8 = 128/block_size
|
# Note 8 = 128/block_size
|
||||||
"forced_num_gpu_blocks": 2 * (8 + 1),
|
"num_gpu_blocks_override": 2 * (8 + 1),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"block_size": 8,
|
"block_size": 8,
|
||||||
|
|
||||||
# Allow only 2 sequences of ~128 tokens in worst case.
|
# Allow only 2 sequences of ~128 tokens in worst case.
|
||||||
# Note 16 = 128/block_size
|
# Note 16 = 128/block_size
|
||||||
"forced_num_gpu_blocks": 2 * (16 + 1),
|
"num_gpu_blocks_override": 2 * (16 + 1),
|
||||||
}
|
}
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{
|
||||||
|
|||||||
@ -3,8 +3,8 @@ import random
|
|||||||
import tempfile
|
import tempfile
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig,
|
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
|
||||||
SchedulerConfig)
|
ParallelConfig, SchedulerConfig)
|
||||||
from vllm.lora.models import LoRAMapping
|
from vllm.lora.models import LoRAMapping
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.worker.worker import Worker
|
from vllm.worker.worker import Worker
|
||||||
@ -27,6 +27,10 @@ def test_worker_apply_lora(sql_lora_files):
|
|||||||
parallel_config=ParallelConfig(1, 1, False),
|
parallel_config=ParallelConfig(1, 1, False),
|
||||||
scheduler_config=SchedulerConfig(32, 32, 32),
|
scheduler_config=SchedulerConfig(32, 32, 32),
|
||||||
device_config=DeviceConfig("cuda"),
|
device_config=DeviceConfig("cuda"),
|
||||||
|
cache_config=CacheConfig(block_size=16,
|
||||||
|
gpu_memory_utilization=1.,
|
||||||
|
swap_space=0,
|
||||||
|
cache_dtype="auto"),
|
||||||
local_rank=0,
|
local_rank=0,
|
||||||
rank=0,
|
rank=0,
|
||||||
lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32,
|
lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32,
|
||||||
|
|||||||
@ -512,8 +512,8 @@ def test_init_device():
|
|||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def test_init_cache_engine():
|
def test_initialize_cache():
|
||||||
"""Verify SpecDecodeWorker invokes init_cache_engine on proposer/scorer
|
"""Verify SpecDecodeWorker invokes initialize_cache on proposer/scorer
|
||||||
workers.
|
workers.
|
||||||
"""
|
"""
|
||||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||||
@ -525,12 +525,11 @@ def test_init_cache_engine():
|
|||||||
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
|
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
|
||||||
metrics_collector)
|
metrics_collector)
|
||||||
|
|
||||||
cache_config = MagicMock()
|
kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023}
|
||||||
|
worker.initialize_cache(**kwargs)
|
||||||
|
|
||||||
worker.init_cache_engine(cache_config)
|
draft_worker.initialize_cache.assert_called_once_with(**kwargs)
|
||||||
|
target_worker.initialize_cache.assert_called_once_with(**kwargs)
|
||||||
draft_worker.init_cache_engine.assert_called_once_with(cache_config)
|
|
||||||
target_worker.init_cache_engine.assert_called_once_with(cache_config)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('available_gpu_blocks', [1, 1024])
|
@pytest.mark.parametrize('available_gpu_blocks', [1, 1024])
|
||||||
@ -538,10 +537,10 @@ def test_init_cache_engine():
|
|||||||
@pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096])
|
@pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096])
|
||||||
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
|
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
|
||||||
@pytest.mark.skip_global_cleanup
|
@pytest.mark.skip_global_cleanup
|
||||||
def test_profile_num_available_blocks(available_gpu_blocks: int,
|
def test_determine_num_available_blocks(available_gpu_blocks: int,
|
||||||
available_cpu_blocks: int,
|
available_cpu_blocks: int,
|
||||||
target_cache_block_size_bytes: int,
|
target_cache_block_size_bytes: int,
|
||||||
draft_kv_size_bytes: int):
|
draft_kv_size_bytes: int):
|
||||||
"""Verify SpecDecodeWorker correctly profiles num available GPU blocks.
|
"""Verify SpecDecodeWorker correctly profiles num available GPU blocks.
|
||||||
Specifically, it should run profiling in the scorer worker, and then evenly
|
Specifically, it should run profiling in the scorer worker, and then evenly
|
||||||
split the blocks between proposer and scorer worker.
|
split the blocks between proposer and scorer worker.
|
||||||
@ -552,7 +551,7 @@ def test_profile_num_available_blocks(available_gpu_blocks: int,
|
|||||||
rejection_sampler.token_id_dtype = torch.int64
|
rejection_sampler.token_id_dtype = torch.int64
|
||||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||||
|
|
||||||
target_worker.profile_num_available_blocks.return_value = (
|
target_worker.determine_num_available_blocks.return_value = (
|
||||||
available_gpu_blocks, available_cpu_blocks)
|
available_gpu_blocks, available_cpu_blocks)
|
||||||
target_worker.get_cache_block_size_bytes.return_value = (
|
target_worker.get_cache_block_size_bytes.return_value = (
|
||||||
target_cache_block_size_bytes)
|
target_cache_block_size_bytes)
|
||||||
@ -561,17 +560,9 @@ def test_profile_num_available_blocks(available_gpu_blocks: int,
|
|||||||
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
|
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
|
||||||
metrics_collector)
|
metrics_collector)
|
||||||
|
|
||||||
# These values do not directly impact the adjusted block size calculation,
|
num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks()
|
||||||
# so they can be fixed.
|
|
||||||
gpu_memory_utilization = 0.9
|
|
||||||
cpu_swap_space = 100
|
|
||||||
block_size = 16
|
|
||||||
|
|
||||||
num_gpu_blocks, num_cpu_blocks = worker.profile_num_available_blocks(
|
target_worker.determine_num_available_blocks.assert_called_once()
|
||||||
block_size, gpu_memory_utilization, cpu_swap_space, cache_dtype="auto")
|
|
||||||
|
|
||||||
target_worker.profile_num_available_blocks.assert_called_once_with(
|
|
||||||
block_size, gpu_memory_utilization, cpu_swap_space, "auto")
|
|
||||||
assert num_cpu_blocks == available_cpu_blocks
|
assert num_cpu_blocks == available_cpu_blocks
|
||||||
|
|
||||||
assert num_gpu_blocks == split_num_cache_blocks_evenly(
|
assert num_gpu_blocks == split_num_cache_blocks_evenly(
|
||||||
|
|||||||
@ -117,6 +117,7 @@ def create_worker(cls: type,
|
|||||||
parallel_config=engine_config.parallel_config,
|
parallel_config=engine_config.parallel_config,
|
||||||
scheduler_config=engine_config.scheduler_config,
|
scheduler_config=engine_config.scheduler_config,
|
||||||
device_config=engine_config.device_config,
|
device_config=engine_config.device_config,
|
||||||
|
cache_config=engine_config.cache_config,
|
||||||
local_rank=0,
|
local_rank=0,
|
||||||
rank=0,
|
rank=0,
|
||||||
distributed_init_method=distributed_init_method,
|
distributed_init_method=distributed_init_method,
|
||||||
@ -128,8 +129,9 @@ def create_worker(cls: type,
|
|||||||
|
|
||||||
engine_config.cache_config.num_gpu_blocks = num_gpu_blocks
|
engine_config.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||||
engine_config.cache_config.num_cpu_blocks = 0
|
engine_config.cache_config.num_cpu_blocks = 0
|
||||||
worker.init_cache_engine(engine_config.cache_config)
|
worker.initialize_cache(
|
||||||
worker.warm_up_model()
|
num_gpu_blocks=engine_config.cache_config.num_gpu_blocks,
|
||||||
|
num_cpu_blocks=engine_config.cache_config.num_cpu_blocks)
|
||||||
|
|
||||||
return worker
|
return worker
|
||||||
|
|
||||||
|
|||||||
@ -11,8 +11,8 @@ def test_swap() -> None:
|
|||||||
dtype="half",
|
dtype="half",
|
||||||
load_format="dummy")
|
load_format="dummy")
|
||||||
engine_config = engine_args.create_engine_config()
|
engine_config = engine_args.create_engine_config()
|
||||||
engine_config.cache_config.num_gpu_blocks = 100
|
engine_config.cache_config.num_gpu_blocks = 1000
|
||||||
engine_config.cache_config.num_cpu_blocks = 100
|
engine_config.cache_config.num_cpu_blocks = 1000
|
||||||
|
|
||||||
# Create the worker.
|
# Create the worker.
|
||||||
distributed_init_method = get_distributed_init_method(
|
distributed_init_method = get_distributed_init_method(
|
||||||
@ -22,6 +22,7 @@ def test_swap() -> None:
|
|||||||
parallel_config=engine_config.parallel_config,
|
parallel_config=engine_config.parallel_config,
|
||||||
scheduler_config=engine_config.scheduler_config,
|
scheduler_config=engine_config.scheduler_config,
|
||||||
device_config=engine_config.device_config,
|
device_config=engine_config.device_config,
|
||||||
|
cache_config=engine_config.cache_config,
|
||||||
local_rank=0,
|
local_rank=0,
|
||||||
rank=0,
|
rank=0,
|
||||||
distributed_init_method=distributed_init_method,
|
distributed_init_method=distributed_init_method,
|
||||||
@ -31,8 +32,9 @@ def test_swap() -> None:
|
|||||||
# Initialize the worker.
|
# Initialize the worker.
|
||||||
worker.init_device()
|
worker.init_device()
|
||||||
worker.load_model()
|
worker.load_model()
|
||||||
worker.init_cache_engine(engine_config.cache_config)
|
worker.initialize_cache(
|
||||||
worker.warm_up_model()
|
num_gpu_blocks=engine_config.cache_config.num_gpu_blocks,
|
||||||
|
num_cpu_blocks=engine_config.cache_config.num_cpu_blocks)
|
||||||
|
|
||||||
# Randomly initialize the cache.
|
# Randomly initialize the cache.
|
||||||
gpu_cache = worker.cache_engine.gpu_cache
|
gpu_cache = worker.cache_engine.gpu_cache
|
||||||
|
|||||||
@ -334,7 +334,7 @@ class CacheConfig:
|
|||||||
vLLM execution.
|
vLLM execution.
|
||||||
swap_space: Size of the CPU swap space per GPU (in GiB).
|
swap_space: Size of the CPU swap space per GPU (in GiB).
|
||||||
cache_dtype: Data type for kv cache storage.
|
cache_dtype: Data type for kv cache storage.
|
||||||
forced_num_gpu_blocks: Number of GPU blocks to use. This overrides the
|
num_gpu_blocks_override: Number of GPU blocks to use. This overrides the
|
||||||
profiled num_gpu_blocks if specified. Does nothing if None.
|
profiled num_gpu_blocks if specified. Does nothing if None.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -344,14 +344,14 @@ class CacheConfig:
|
|||||||
gpu_memory_utilization: float,
|
gpu_memory_utilization: float,
|
||||||
swap_space: int,
|
swap_space: int,
|
||||||
cache_dtype: str,
|
cache_dtype: str,
|
||||||
forced_num_gpu_blocks: Optional[int] = None,
|
num_gpu_blocks_override: Optional[int] = None,
|
||||||
sliding_window: Optional[int] = None,
|
sliding_window: Optional[int] = None,
|
||||||
enable_prefix_caching: bool = False,
|
enable_prefix_caching: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.block_size = block_size
|
self.block_size = block_size
|
||||||
self.gpu_memory_utilization = gpu_memory_utilization
|
self.gpu_memory_utilization = gpu_memory_utilization
|
||||||
self.swap_space_bytes = swap_space * _GB
|
self.swap_space_bytes = swap_space * _GB
|
||||||
self.forced_num_gpu_blocks = forced_num_gpu_blocks
|
self.num_gpu_blocks_override = num_gpu_blocks_override
|
||||||
self.cache_dtype = cache_dtype
|
self.cache_dtype = cache_dtype
|
||||||
self.sliding_window = sliding_window
|
self.sliding_window = sliding_window
|
||||||
self.enable_prefix_caching = enable_prefix_caching
|
self.enable_prefix_caching = enable_prefix_caching
|
||||||
|
|||||||
@ -55,7 +55,7 @@ class EngineArgs:
|
|||||||
max_cpu_loras: Optional[int] = None
|
max_cpu_loras: Optional[int] = None
|
||||||
device: str = 'auto'
|
device: str = 'auto'
|
||||||
ray_workers_use_nsight: bool = False
|
ray_workers_use_nsight: bool = False
|
||||||
forced_num_gpu_blocks: Optional[int] = None
|
num_gpu_blocks_override: Optional[int] = None
|
||||||
num_lookahead_slots: int = 0
|
num_lookahead_slots: int = 0
|
||||||
|
|
||||||
# Related to Vision-language models such as llava
|
# Related to Vision-language models such as llava
|
||||||
@ -246,7 +246,7 @@ class EngineArgs:
|
|||||||
'the model executor, which can range from 0 to 1.'
|
'the model executor, which can range from 0 to 1.'
|
||||||
'If unspecified, will use the default value of 0.9.')
|
'If unspecified, will use the default value of 0.9.')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--forced-num-gpu-blocks',
|
'--num-gpu-blocks-override',
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=None,
|
||||||
help='If specified, ignore GPU profiling result and use this number'
|
help='If specified, ignore GPU profiling result and use this number'
|
||||||
@ -426,7 +426,7 @@ class EngineArgs:
|
|||||||
cache_config = CacheConfig(self.block_size,
|
cache_config = CacheConfig(self.block_size,
|
||||||
self.gpu_memory_utilization,
|
self.gpu_memory_utilization,
|
||||||
self.swap_space, self.kv_cache_dtype,
|
self.swap_space, self.kv_cache_dtype,
|
||||||
self.forced_num_gpu_blocks,
|
self.num_gpu_blocks_override,
|
||||||
model_config.get_sliding_window(),
|
model_config.get_sliding_window(),
|
||||||
self.enable_prefix_caching)
|
self.enable_prefix_caching)
|
||||||
parallel_config = ParallelConfig(
|
parallel_config = ParallelConfig(
|
||||||
|
|||||||
@ -127,6 +127,8 @@ class LLMEngine:
|
|||||||
speculative_config=speculative_config,
|
speculative_config=speculative_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._initialize_kv_caches()
|
||||||
|
|
||||||
# If usage stat is enabled, collect relevant info.
|
# If usage stat is enabled, collect relevant info.
|
||||||
if is_usage_stats_enabled():
|
if is_usage_stats_enabled():
|
||||||
from vllm.model_executor.model_loader import (
|
from vllm.model_executor.model_loader import (
|
||||||
@ -178,6 +180,26 @@ class LLMEngine:
|
|||||||
labels=dict(model_name=model_config.model))
|
labels=dict(model_name=model_config.model))
|
||||||
self.stat_logger.info("cache_config", self.cache_config)
|
self.stat_logger.info("cache_config", self.cache_config)
|
||||||
|
|
||||||
|
def _initialize_kv_caches(self) -> None:
|
||||||
|
"""Initialize the KV cache in the worker(s).
|
||||||
|
|
||||||
|
The workers will determine the number of blocks in both the GPU cache
|
||||||
|
and the swap CPU cache.
|
||||||
|
"""
|
||||||
|
num_gpu_blocks, num_cpu_blocks = (
|
||||||
|
self.model_executor.determine_num_available_blocks())
|
||||||
|
|
||||||
|
if self.cache_config.num_gpu_blocks_override is not None:
|
||||||
|
num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
|
||||||
|
logger.info(f"Overriding {num_gpu_blocks=} with "
|
||||||
|
f"{num_gpu_blocks_override=}")
|
||||||
|
num_gpu_blocks = num_gpu_blocks_override
|
||||||
|
|
||||||
|
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||||
|
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||||
|
|
||||||
|
self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_engine_args(
|
def from_engine_args(
|
||||||
cls,
|
cls,
|
||||||
|
|||||||
@ -35,7 +35,6 @@ class CPUExecutor(ExecutorBase):
|
|||||||
|
|
||||||
# Instantiate the worker and load the model to CPU.
|
# Instantiate the worker and load the model to CPU.
|
||||||
self._init_worker()
|
self._init_worker()
|
||||||
self._init_cache()
|
|
||||||
|
|
||||||
def _init_worker(self):
|
def _init_worker(self):
|
||||||
from vllm.worker.cpu_worker import CPUWorker
|
from vllm.worker.cpu_worker import CPUWorker
|
||||||
@ -46,10 +45,11 @@ class CPUExecutor(ExecutorBase):
|
|||||||
distributed_init_method = get_distributed_init_method(
|
distributed_init_method = get_distributed_init_method(
|
||||||
get_ip(), get_open_port())
|
get_ip(), get_open_port())
|
||||||
self.driver_worker = CPUWorker(
|
self.driver_worker = CPUWorker(
|
||||||
self.model_config,
|
model_config=self.model_config,
|
||||||
self.parallel_config,
|
parallel_config=self.parallel_config,
|
||||||
self.scheduler_config,
|
scheduler_config=self.scheduler_config,
|
||||||
self.device_config,
|
device_config=self.device_config,
|
||||||
|
cache_config=self.cache_config,
|
||||||
local_rank=0,
|
local_rank=0,
|
||||||
rank=0,
|
rank=0,
|
||||||
distributed_init_method=distributed_init_method,
|
distributed_init_method=distributed_init_method,
|
||||||
@ -60,35 +60,21 @@ class CPUExecutor(ExecutorBase):
|
|||||||
self.driver_worker.init_device()
|
self.driver_worker.init_device()
|
||||||
self.driver_worker.load_model()
|
self.driver_worker.load_model()
|
||||||
|
|
||||||
def _init_cache(self) -> None:
|
def determine_num_available_blocks(self) -> tuple[int, int]:
|
||||||
num_cpu_blocks = self.driver_worker.get_cpu_cache_block_num(
|
"""Determine the number of available KV blocks by invoking the
|
||||||
block_size=self.cache_config.block_size,
|
underlying worker.
|
||||||
cache_space=self.cache_config.cpu_kvcache_space_bytes,
|
"""
|
||||||
cache_dtype=self.cache_config.cache_dtype,
|
return self.driver_worker.determine_num_available_blocks()
|
||||||
)
|
|
||||||
|
|
||||||
|
def initialize_cache(self, num_gpu_blocks: int,
|
||||||
|
num_cpu_blocks: int) -> None:
|
||||||
|
"""Initialize the KV cache by invoking the underlying worker.
|
||||||
|
"""
|
||||||
|
# NOTE: We log here to avoid multiple logs when number of workers is
|
||||||
|
# greater than one. We could log in the engine, but not all executors
|
||||||
|
# have GPUs.
|
||||||
logger.info(f"# CPU blocks: {num_cpu_blocks}")
|
logger.info(f"# CPU blocks: {num_cpu_blocks}")
|
||||||
if num_cpu_blocks <= 0:
|
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
|
||||||
raise ValueError("No available memory for the cache blocks. "
|
|
||||||
"Try increasing `VLLM_CPU_KVCACHE_SPACE` when "
|
|
||||||
"initializing the engine.")
|
|
||||||
|
|
||||||
max_seq_len = self.cache_config.block_size * num_cpu_blocks
|
|
||||||
if self.model_config.max_model_len > max_seq_len:
|
|
||||||
raise ValueError(
|
|
||||||
f"The model's max seq len ({self.model_config.max_model_len}) "
|
|
||||||
"is larger than the maximum number of tokens that can be "
|
|
||||||
f"stored in KV cache ({max_seq_len}). Try increasing "
|
|
||||||
"`VLLM_CPU_KVCACHE_SPACE` or decreasing `max_model_len` when "
|
|
||||||
"initializing the engine.")
|
|
||||||
|
|
||||||
# Note: To reuse the cache management procedure,
|
|
||||||
# use cpu cache as 'gpu cache'.
|
|
||||||
self.cache_config.num_gpu_blocks = num_cpu_blocks # type: ignore
|
|
||||||
self.cache_config.num_cpu_blocks = 0 # type: ignore
|
|
||||||
|
|
||||||
# Initialize the cache.
|
|
||||||
self.driver_worker.init_cache_engine(cache_config=self.cache_config)
|
|
||||||
|
|
||||||
def execute_model(self,
|
def execute_model(self,
|
||||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
@ -104,13 +90,13 @@ class CPUExecutor(ExecutorBase):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||||
raise NotImplementedError("LoRA is not implemented for cpu backend.")
|
return self.driver_worker.add_lora(lora_request)
|
||||||
|
|
||||||
def remove_lora(self, lora_id: int) -> bool:
|
def remove_lora(self, lora_id: int) -> bool:
|
||||||
raise NotImplementedError("LoRA is not implemented for cpu backend.")
|
return self.driver_worker.remove_lora(lora_id)
|
||||||
|
|
||||||
def list_loras(self) -> List[int]:
|
def list_loras(self) -> List[int]:
|
||||||
raise NotImplementedError("LoRA is not implemented for cpu backend.")
|
return self.driver_worker.list_loras()
|
||||||
|
|
||||||
def check_health(self) -> None:
|
def check_health(self) -> None:
|
||||||
# CPUExecutor will always be healthy as long as
|
# CPUExecutor will always be healthy as long as
|
||||||
|
|||||||
@ -30,6 +30,29 @@ class ExecutorBase(ABC):
|
|||||||
) -> None:
|
) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def determine_num_available_blocks(self) -> tuple[int, int]:
|
||||||
|
"""Determine the number of available blocks for the GPU KV cache and
|
||||||
|
swappable CPU KV cache.
|
||||||
|
|
||||||
|
Normally, this should simply delegate to the underlying Worker. Some
|
||||||
|
ExecutorBase may require modification of the result, e.g. to ensure the
|
||||||
|
selected cache sizes are compatible with all workers.
|
||||||
|
|
||||||
|
Returns a tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
|
||||||
|
are blocks that are "active" on the device and can be appended to.
|
||||||
|
num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
|
||||||
|
appended to.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def initialize_cache(self, num_gpu_blocks: int,
|
||||||
|
num_cpu_blocks: int) -> None:
|
||||||
|
"""Initialize the KV cache with the given size in blocks.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def execute_model(self,
|
def execute_model(self,
|
||||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
|
|||||||
@ -4,7 +4,6 @@ from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
|
|||||||
ParallelConfig, SchedulerConfig, SpeculativeConfig,
|
ParallelConfig, SchedulerConfig, SpeculativeConfig,
|
||||||
VisionLanguageConfig)
|
VisionLanguageConfig)
|
||||||
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
|
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
|
||||||
from vllm.executor.utils import check_block_size_valid
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
||||||
@ -41,9 +40,6 @@ class GPUExecutor(ExecutorBase):
|
|||||||
# Instantiate the worker and load the model to GPU.
|
# Instantiate the worker and load the model to GPU.
|
||||||
self._init_worker()
|
self._init_worker()
|
||||||
|
|
||||||
# Profile the memory usage and initialize the cache.
|
|
||||||
self._init_cache()
|
|
||||||
|
|
||||||
def _init_worker(self):
|
def _init_worker(self):
|
||||||
# Lazy import the Worker to avoid importing torch.cuda/xformers
|
# Lazy import the Worker to avoid importing torch.cuda/xformers
|
||||||
# before CUDA_VISIBLE_DEVICES is set in the Worker
|
# before CUDA_VISIBLE_DEVICES is set in the Worker
|
||||||
@ -55,61 +51,37 @@ class GPUExecutor(ExecutorBase):
|
|||||||
distributed_init_method = get_distributed_init_method(
|
distributed_init_method = get_distributed_init_method(
|
||||||
get_ip(), get_open_port())
|
get_ip(), get_open_port())
|
||||||
self.driver_worker = Worker(
|
self.driver_worker = Worker(
|
||||||
self.model_config,
|
model_config=self.model_config,
|
||||||
self.parallel_config,
|
parallel_config=self.parallel_config,
|
||||||
self.scheduler_config,
|
scheduler_config=self.scheduler_config,
|
||||||
self.device_config,
|
device_config=self.device_config,
|
||||||
|
cache_config=self.cache_config,
|
||||||
local_rank=0,
|
local_rank=0,
|
||||||
rank=0,
|
rank=0,
|
||||||
distributed_init_method=distributed_init_method,
|
distributed_init_method=distributed_init_method,
|
||||||
lora_config=self.lora_config,
|
lora_config=self.lora_config,
|
||||||
vision_language_config=self.vision_language_config,
|
vision_language_config=self.vision_language_config,
|
||||||
kv_cache_dtype=self.cache_config.cache_dtype,
|
|
||||||
is_driver_worker=True,
|
is_driver_worker=True,
|
||||||
)
|
)
|
||||||
self.driver_worker.init_device()
|
self.driver_worker.init_device()
|
||||||
self.driver_worker.load_model()
|
self.driver_worker.load_model()
|
||||||
|
|
||||||
def _init_cache(self) -> None:
|
def determine_num_available_blocks(self) -> tuple[int, int]:
|
||||||
"""Profiles the memory usage and initializes the KV cache.
|
"""Determine the number of available KV blocks by invoking the
|
||||||
|
underlying worker.
|
||||||
The engine first profiles the existing memory usage.
|
|
||||||
Then, it allocates the remaining memory for KV blocks.
|
|
||||||
|
|
||||||
.. tip::
|
|
||||||
You may limit the usage of GPU memory
|
|
||||||
by adjusting the `gpu_memory_utilization` parameter.
|
|
||||||
"""
|
"""
|
||||||
# Get the maximum number of blocks that can be allocated on GPU and CPU.
|
return self.driver_worker.determine_num_available_blocks()
|
||||||
num_gpu_blocks, num_cpu_blocks = (
|
|
||||||
self.driver_worker.profile_num_available_blocks(
|
|
||||||
block_size=self.cache_config.block_size,
|
|
||||||
gpu_memory_utilization=self.cache_config.
|
|
||||||
gpu_memory_utilization,
|
|
||||||
cpu_swap_space=self.cache_config.swap_space_bytes,
|
|
||||||
cache_dtype=self.cache_config.cache_dtype,
|
|
||||||
))
|
|
||||||
|
|
||||||
if self.cache_config.forced_num_gpu_blocks is not None:
|
|
||||||
forced_num_gpu_blocks = self.cache_config.forced_num_gpu_blocks
|
|
||||||
logger.info(f"Replacing profiled {num_gpu_blocks=} with "
|
|
||||||
f"{forced_num_gpu_blocks=}")
|
|
||||||
num_gpu_blocks = forced_num_gpu_blocks
|
|
||||||
|
|
||||||
|
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None:
|
||||||
|
"""Initialize the KV cache by invoking the underlying worker.
|
||||||
|
"""
|
||||||
|
# NOTE: This is logged in the executor because there can be >1 worker
|
||||||
|
# with other executors. We could log in the engine level, but work
|
||||||
|
# remains to abstract away the device for non-GPU configurations.
|
||||||
logger.info(f"# GPU blocks: {num_gpu_blocks}, "
|
logger.info(f"# GPU blocks: {num_gpu_blocks}, "
|
||||||
f"# CPU blocks: {num_cpu_blocks}")
|
f"# CPU blocks: {num_cpu_blocks}")
|
||||||
|
|
||||||
check_block_size_valid(num_gpu_blocks, self.cache_config.block_size,
|
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
|
||||||
self.model_config.max_model_len)
|
|
||||||
|
|
||||||
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
|
||||||
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
|
||||||
|
|
||||||
# Initialize the cache.
|
|
||||||
self.driver_worker.init_cache_engine(cache_config=self.cache_config)
|
|
||||||
# Warm up the model. This includes capturing the model into CUDA graph
|
|
||||||
# if enforce_eager is False.
|
|
||||||
self.driver_worker.warm_up_model()
|
|
||||||
|
|
||||||
def execute_model(self,
|
def execute_model(self,
|
||||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
|
|||||||
@ -25,7 +25,6 @@ class NeuronExecutor(ExecutorBase):
|
|||||||
speculative_config: Optional[SpeculativeConfig],
|
speculative_config: Optional[SpeculativeConfig],
|
||||||
) -> None:
|
) -> None:
|
||||||
self.model_config = model_config
|
self.model_config = model_config
|
||||||
self.cache_config = cache_config
|
|
||||||
assert lora_config is None, "LoRA is not supported for Neuron backend."
|
assert lora_config is None, "LoRA is not supported for Neuron backend."
|
||||||
self.parallel_config = parallel_config
|
self.parallel_config = parallel_config
|
||||||
self.scheduler_config = scheduler_config
|
self.scheduler_config = scheduler_config
|
||||||
@ -33,12 +32,6 @@ class NeuronExecutor(ExecutorBase):
|
|||||||
assert (not speculative_config
|
assert (not speculative_config
|
||||||
), "Speculative decoding not yet supported for Neuron backend."
|
), "Speculative decoding not yet supported for Neuron backend."
|
||||||
|
|
||||||
# Set the number of GPU blocks to be the same as the maximum number of
|
|
||||||
# sequences that can be processed in a single batch. This is equivalent
|
|
||||||
# to schedule without PagedAttention.
|
|
||||||
self.cache_config.num_gpu_blocks = self.scheduler_config.max_num_seqs
|
|
||||||
self.cache_config.num_cpu_blocks = 0
|
|
||||||
|
|
||||||
# Instantiate the worker and load the model to the device.
|
# Instantiate the worker and load the model to the device.
|
||||||
self._init_worker()
|
self._init_worker()
|
||||||
|
|
||||||
@ -54,6 +47,18 @@ class NeuronExecutor(ExecutorBase):
|
|||||||
self.driver_worker.init_device()
|
self.driver_worker.init_device()
|
||||||
self.driver_worker.load_model()
|
self.driver_worker.load_model()
|
||||||
|
|
||||||
|
def determine_num_available_blocks(self) -> tuple[int, int]:
|
||||||
|
"""Determine the number of available KV blocks by invoking the
|
||||||
|
underlying worker.
|
||||||
|
"""
|
||||||
|
return self.driver_worker.determine_num_available_blocks()
|
||||||
|
|
||||||
|
def initialize_cache(self, num_gpu_blocks: int,
|
||||||
|
num_cpu_blocks: int) -> None:
|
||||||
|
"""Initialize the KV cache by invoking the underlying worker.
|
||||||
|
"""
|
||||||
|
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
|
||||||
|
|
||||||
def execute_model(self,
|
def execute_model(self,
|
||||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
blocks_to_swap_in: Dict[int, int],
|
blocks_to_swap_in: Dict[int, int],
|
||||||
@ -68,16 +73,13 @@ class NeuronExecutor(ExecutorBase):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||||
raise NotImplementedError(
|
return self.driver_worker.add_lora(lora_request)
|
||||||
"LoRA is not implemented for neuron backend.")
|
|
||||||
|
|
||||||
def remove_lora(self, lora_id: int) -> bool:
|
def remove_lora(self, lora_id: int) -> bool:
|
||||||
raise NotImplementedError(
|
return self.driver_worker.remove_lora(lora_id)
|
||||||
"LoRA is not implemented for neuron backend.")
|
|
||||||
|
|
||||||
def list_loras(self) -> List[int]:
|
def list_loras(self) -> List[int]:
|
||||||
raise NotImplementedError(
|
return self.driver_worker.list_loras()
|
||||||
"LoRA is not implemented for neuron backend.")
|
|
||||||
|
|
||||||
def check_health(self) -> None:
|
def check_health(self) -> None:
|
||||||
# NeuronExecutor will always be healthy as long as
|
# NeuronExecutor will always be healthy as long as
|
||||||
|
|||||||
@ -10,7 +10,6 @@ from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
|
|||||||
VisionLanguageConfig)
|
VisionLanguageConfig)
|
||||||
from vllm.engine.ray_utils import RayWorkerVllm, ray
|
from vllm.engine.ray_utils import RayWorkerVllm, ray
|
||||||
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
|
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
|
||||||
from vllm.executor.utils import check_block_size_valid
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
||||||
@ -65,9 +64,6 @@ class RayGPUExecutor(ExecutorBase):
|
|||||||
# Create the parallel GPU workers.
|
# Create the parallel GPU workers.
|
||||||
self._init_workers_ray(placement_group)
|
self._init_workers_ray(placement_group)
|
||||||
|
|
||||||
# Profile the memory usage and initialize the cache.
|
|
||||||
self._init_cache()
|
|
||||||
|
|
||||||
self.forward_dag = None
|
self.forward_dag = None
|
||||||
if USE_RAY_COMPILED_DAG:
|
if USE_RAY_COMPILED_DAG:
|
||||||
self.forward_dag = self._compiled_ray_dag()
|
self.forward_dag = self._compiled_ray_dag()
|
||||||
@ -154,8 +150,8 @@ class RayGPUExecutor(ExecutorBase):
|
|||||||
scheduler_config = copy.deepcopy(self.scheduler_config)
|
scheduler_config = copy.deepcopy(self.scheduler_config)
|
||||||
device_config = copy.deepcopy(self.device_config)
|
device_config = copy.deepcopy(self.device_config)
|
||||||
lora_config = copy.deepcopy(self.lora_config)
|
lora_config = copy.deepcopy(self.lora_config)
|
||||||
|
cache_config = copy.deepcopy(self.cache_config)
|
||||||
vision_language_config = copy.deepcopy(self.vision_language_config)
|
vision_language_config = copy.deepcopy(self.vision_language_config)
|
||||||
kv_cache_dtype = self.cache_config.cache_dtype
|
|
||||||
|
|
||||||
# Initialize the actual workers with the Worker class.
|
# Initialize the actual workers with the Worker class.
|
||||||
for rank, (worker, (node_id, _)) in enumerate(
|
for rank, (worker, (node_id, _)) in enumerate(
|
||||||
@ -165,32 +161,32 @@ class RayGPUExecutor(ExecutorBase):
|
|||||||
local_rank = node_workers[node_id].index(rank)
|
local_rank = node_workers[node_id].index(rank)
|
||||||
worker.init_worker.remote(
|
worker.init_worker.remote(
|
||||||
lambda rank=rank, local_rank=local_rank: Worker(
|
lambda rank=rank, local_rank=local_rank: Worker(
|
||||||
model_config,
|
model_config=model_config,
|
||||||
parallel_config,
|
parallel_config=parallel_config,
|
||||||
scheduler_config,
|
scheduler_config=scheduler_config,
|
||||||
device_config,
|
device_config=device_config,
|
||||||
local_rank,
|
cache_config=cache_config,
|
||||||
rank,
|
local_rank=local_rank,
|
||||||
distributed_init_method,
|
rank=rank,
|
||||||
|
distributed_init_method=distributed_init_method,
|
||||||
lora_config=lora_config,
|
lora_config=lora_config,
|
||||||
vision_language_config=vision_language_config,
|
vision_language_config=vision_language_config,
|
||||||
kv_cache_dtype=kv_cache_dtype,
|
|
||||||
))
|
))
|
||||||
|
|
||||||
# Initialize the driver worker with the Worker class.
|
# Initialize the driver worker with the Worker class.
|
||||||
driver_rank = 0
|
driver_rank = 0
|
||||||
driver_local_rank = node_workers[driver_node_id].index(driver_rank)
|
driver_local_rank = node_workers[driver_node_id].index(driver_rank)
|
||||||
self.driver_worker = Worker(
|
self.driver_worker = Worker(
|
||||||
self.model_config,
|
model_config=self.model_config,
|
||||||
self.parallel_config,
|
parallel_config=self.parallel_config,
|
||||||
self.scheduler_config,
|
scheduler_config=self.scheduler_config,
|
||||||
self.device_config,
|
device_config=self.device_config,
|
||||||
driver_local_rank,
|
cache_config=self.cache_config,
|
||||||
driver_rank,
|
local_rank=driver_local_rank,
|
||||||
distributed_init_method,
|
rank=driver_rank,
|
||||||
|
distributed_init_method=distributed_init_method,
|
||||||
lora_config=self.lora_config,
|
lora_config=self.lora_config,
|
||||||
vision_language_config=self.vision_language_config,
|
vision_language_config=self.vision_language_config,
|
||||||
kv_cache_dtype=kv_cache_dtype,
|
|
||||||
is_driver_worker=True,
|
is_driver_worker=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -201,35 +197,18 @@ class RayGPUExecutor(ExecutorBase):
|
|||||||
max_parallel_loading_workers,
|
max_parallel_loading_workers,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _init_cache(self) -> None:
|
def determine_num_available_blocks(self) -> tuple[int, int]:
|
||||||
"""Profiles the memory usage and initializes the KV cache.
|
"""Determine the number of available KV blocks.
|
||||||
|
|
||||||
The engine will first conduct a profiling of the existing memory usage.
|
This invokes `determine_num_available_blocks` on each worker and takes
|
||||||
Then, it calculate the maximum possible number of GPU and CPU blocks
|
the min of the results, guaranteeing that the selected cache sizes are
|
||||||
that can be allocated with the remaining free memory.
|
compatible with all workers.
|
||||||
More details can be found in the
|
|
||||||
:meth:`~vllm.worker.worker.Worker.profile_num_available_blocks` method
|
|
||||||
from class :class:`~vllm.worker.Worker`.
|
|
||||||
|
|
||||||
Afterwards, as there may be multiple workers,
|
Returns:
|
||||||
we take the minimum number of blocks across all workers
|
- tuple[num_gpu_blocks, num_cpu_blocks]
|
||||||
to ensure this can be applied to all of them.
|
|
||||||
|
|
||||||
Finally, the engine will initialize the KV cache
|
|
||||||
with the calculated number of blocks.
|
|
||||||
|
|
||||||
.. tip::
|
|
||||||
You may limit the usage of GPU memory
|
|
||||||
by adjusting the `gpu_memory_utilization` parameter.
|
|
||||||
"""
|
"""
|
||||||
# Get the maximum number of blocks that can be allocated on GPU and CPU.
|
# Get the maximum number of blocks that can be allocated on GPU and CPU.
|
||||||
num_blocks = self._run_workers(
|
num_blocks = self._run_workers("determine_num_available_blocks", )
|
||||||
"profile_num_available_blocks",
|
|
||||||
block_size=self.cache_config.block_size,
|
|
||||||
gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
|
|
||||||
cpu_swap_space=self.cache_config.swap_space_bytes,
|
|
||||||
cache_dtype=self.cache_config.cache_dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Since we use a shared centralized controller, we take the minimum
|
# Since we use a shared centralized controller, we take the minimum
|
||||||
# number of blocks across all workers to make sure all the memory
|
# number of blocks across all workers to make sure all the memory
|
||||||
@ -237,26 +216,25 @@ class RayGPUExecutor(ExecutorBase):
|
|||||||
num_gpu_blocks = min(b[0] for b in num_blocks)
|
num_gpu_blocks = min(b[0] for b in num_blocks)
|
||||||
num_cpu_blocks = min(b[1] for b in num_blocks)
|
num_cpu_blocks = min(b[1] for b in num_blocks)
|
||||||
|
|
||||||
if self.cache_config.forced_num_gpu_blocks is not None:
|
return num_gpu_blocks, num_cpu_blocks
|
||||||
forced_num_gpu_blocks = self.cache_config.forced_num_gpu_blocks
|
|
||||||
logger.info(f"Replacing profiled {num_gpu_blocks=} with "
|
|
||||||
f"{forced_num_gpu_blocks=}")
|
|
||||||
num_gpu_blocks = forced_num_gpu_blocks
|
|
||||||
|
|
||||||
|
def initialize_cache(self, num_gpu_blocks: int,
|
||||||
|
num_cpu_blocks: int) -> None:
|
||||||
|
"""Initialize the KV cache in all workers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# NOTE: We log here to avoid multiple logs when number of workers is
|
||||||
|
# greater than one. We could log in the engine, but not all executors
|
||||||
|
# have GPUs.
|
||||||
logger.info(f"# GPU blocks: {num_gpu_blocks}, "
|
logger.info(f"# GPU blocks: {num_gpu_blocks}, "
|
||||||
f"# CPU blocks: {num_cpu_blocks}")
|
f"# CPU blocks: {num_cpu_blocks}")
|
||||||
|
|
||||||
check_block_size_valid(num_gpu_blocks, self.cache_config.block_size,
|
|
||||||
self.model_config.max_model_len)
|
|
||||||
|
|
||||||
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||||
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||||
|
|
||||||
# Initialize the cache.
|
self._run_workers("initialize_cache",
|
||||||
self._run_workers("init_cache_engine", cache_config=self.cache_config)
|
num_gpu_blocks=num_gpu_blocks,
|
||||||
# Warm up the model. This includes capturing the model into CUDA graph
|
num_cpu_blocks=num_cpu_blocks)
|
||||||
# if enforce_eager is False.
|
|
||||||
self._run_workers("warm_up_model")
|
|
||||||
|
|
||||||
def execute_model(self,
|
def execute_model(self,
|
||||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
|
|||||||
@ -1,13 +0,0 @@
|
|||||||
def check_block_size_valid(num_gpu_blocks, block_size, max_model_len) -> None:
|
|
||||||
if num_gpu_blocks <= 0:
|
|
||||||
raise ValueError("No available memory for the cache blocks. "
|
|
||||||
"Try increasing `gpu_memory_utilization` when "
|
|
||||||
"initializing the engine.")
|
|
||||||
max_seq_len = block_size * num_gpu_blocks
|
|
||||||
if max_model_len > max_seq_len:
|
|
||||||
raise ValueError(
|
|
||||||
f"The model's max seq len ({max_model_len}) "
|
|
||||||
"is larger than the maximum number of tokens that can be "
|
|
||||||
f"stored in KV cache ({max_seq_len}). Try increasing "
|
|
||||||
"`gpu_memory_utilization` or decreasing `max_model_len` when "
|
|
||||||
"initializing the engine.")
|
|
||||||
@ -3,7 +3,6 @@ from typing import Dict, List, Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.config import CacheConfig
|
|
||||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||||
from vllm.sequence import (SamplerOutput, SequenceGroupMetadata,
|
from vllm.sequence import (SamplerOutput, SequenceGroupMetadata,
|
||||||
SequenceGroupOutput, SequenceOutput)
|
SequenceGroupOutput, SequenceOutput)
|
||||||
@ -15,9 +14,10 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
|||||||
from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
|
from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
|
||||||
split_batch_by_proposal_len)
|
split_batch_by_proposal_len)
|
||||||
from vllm.worker.worker import Worker
|
from vllm.worker.worker import Worker
|
||||||
|
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
|
||||||
|
|
||||||
|
|
||||||
class SpecDecodeWorker:
|
class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||||
"""Worker which implements speculative decoding.
|
"""Worker which implements speculative decoding.
|
||||||
|
|
||||||
Speculative decoding reduces decoding per-token latency by using a proposal
|
Speculative decoding reduces decoding per-token latency by using a proposal
|
||||||
@ -94,10 +94,7 @@ class SpecDecodeWorker:
|
|||||||
device=self.device,
|
device=self.device,
|
||||||
vocab_size=self._vocab_size)
|
vocab_size=self._vocab_size)
|
||||||
|
|
||||||
def profile_num_available_blocks(self, block_size: int,
|
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||||
gpu_memory_utilization: float,
|
|
||||||
cpu_swap_space: int,
|
|
||||||
cache_dtype: str) -> Tuple[int, int]:
|
|
||||||
"""Determine the number of cache blocks to use.
|
"""Determine the number of cache blocks to use.
|
||||||
|
|
||||||
This is done by profiling the scorer model (which is typically the
|
This is done by profiling the scorer model (which is typically the
|
||||||
@ -106,27 +103,26 @@ class SpecDecodeWorker:
|
|||||||
such that the number of blocks is equal in both KV caches.
|
such that the number of blocks is equal in both KV caches.
|
||||||
"""
|
"""
|
||||||
num_gpu_blocks, num_cpu_blocks = (
|
num_gpu_blocks, num_cpu_blocks = (
|
||||||
self.scorer_worker.profile_num_available_blocks(
|
self.scorer_worker.determine_num_available_blocks())
|
||||||
block_size, gpu_memory_utilization, cpu_swap_space,
|
|
||||||
cache_dtype))
|
|
||||||
|
|
||||||
scorer_cache_block_size_bytes = (
|
scorer_cache_block_size_bytes = (
|
||||||
self.scorer_worker.get_cache_block_size_bytes(
|
self.scorer_worker.get_cache_block_size_bytes())
|
||||||
block_size, cache_dtype))
|
|
||||||
proposer_cache_block_size_bytes = (
|
proposer_cache_block_size_bytes = (
|
||||||
self.proposer_worker.get_cache_block_size_bytes(
|
self.proposer_worker.get_cache_block_size_bytes())
|
||||||
block_size, cache_dtype))
|
|
||||||
|
|
||||||
new_num_gpu_blocks = split_num_cache_blocks_evenly(
|
new_num_gpu_blocks = split_num_cache_blocks_evenly(
|
||||||
scorer_cache_block_size_bytes, proposer_cache_block_size_bytes,
|
scorer_cache_block_size_bytes, proposer_cache_block_size_bytes,
|
||||||
num_gpu_blocks)
|
num_gpu_blocks)
|
||||||
return new_num_gpu_blocks, num_cpu_blocks
|
return new_num_gpu_blocks, num_cpu_blocks
|
||||||
|
|
||||||
def init_cache_engine(self, cache_config: CacheConfig):
|
def initialize_cache(self, num_gpu_blocks: int,
|
||||||
|
num_cpu_blocks: int) -> None:
|
||||||
"""Initialize the cache engine of the scorer and proposer workers.
|
"""Initialize the cache engine of the scorer and proposer workers.
|
||||||
"""
|
"""
|
||||||
self.scorer_worker.init_cache_engine(cache_config)
|
self.scorer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
|
||||||
self.proposer_worker.init_cache_engine(cache_config)
|
num_cpu_blocks=num_cpu_blocks)
|
||||||
|
self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
|
||||||
|
num_cpu_blocks=num_cpu_blocks)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def execute_model(
|
def execute_model(
|
||||||
@ -351,6 +347,16 @@ class SpecDecodeWorker:
|
|||||||
def device(self):
|
def device(self):
|
||||||
return self.scorer_worker.device
|
return self.scorer_worker.device
|
||||||
|
|
||||||
|
def get_cache_block_size_bytes(self):
|
||||||
|
"""Return the size of a cache block in bytes.
|
||||||
|
|
||||||
|
This function is only used to compose workers within a SpecDecodeWorker.
|
||||||
|
We leave composing a SpecDecodeWorker within a SpecDecodeWorker
|
||||||
|
undefined for now, although it could be implemented in the future.
|
||||||
|
See https://arxiv.org/abs/2308.04623.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
|
def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
|
||||||
proposer_cache_block_size_bytes: int,
|
proposer_cache_block_size_bytes: int,
|
||||||
|
|||||||
@ -82,8 +82,7 @@ class CacheEngine:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_cache_block_size(
|
def get_cache_block_size(
|
||||||
block_size: int,
|
cache_config: CacheConfig,
|
||||||
cache_dtype: str,
|
|
||||||
model_config: ModelConfig,
|
model_config: ModelConfig,
|
||||||
parallel_config: ParallelConfig,
|
parallel_config: ParallelConfig,
|
||||||
) -> int:
|
) -> int:
|
||||||
@ -91,13 +90,13 @@ class CacheEngine:
|
|||||||
num_heads = model_config.get_num_kv_heads(parallel_config)
|
num_heads = model_config.get_num_kv_heads(parallel_config)
|
||||||
num_layers = model_config.get_num_layers(parallel_config)
|
num_layers = model_config.get_num_layers(parallel_config)
|
||||||
|
|
||||||
key_cache_block = block_size * num_heads * head_size
|
key_cache_block = cache_config.block_size * num_heads * head_size
|
||||||
value_cache_block = key_cache_block
|
value_cache_block = key_cache_block
|
||||||
total = num_layers * (key_cache_block + value_cache_block)
|
total = num_layers * (key_cache_block + value_cache_block)
|
||||||
if cache_dtype == "auto":
|
if cache_config.cache_dtype == "auto":
|
||||||
dtype = model_config.dtype
|
dtype = model_config.dtype
|
||||||
else:
|
else:
|
||||||
dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
|
dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
|
||||||
dtype_size = _get_dtype_size(dtype)
|
dtype_size = _get_dtype_size(dtype)
|
||||||
return dtype_size * total
|
return dtype_size * total
|
||||||
|
|
||||||
|
|||||||
@ -17,6 +17,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
|
|||||||
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
||||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
|
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||||
from vllm.worker.model_runner import ModelRunner
|
from vllm.worker.model_runner import ModelRunner
|
||||||
|
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -112,7 +113,7 @@ class CPUCacheEngine:
|
|||||||
return dtype_size * total
|
return dtype_size * total
|
||||||
|
|
||||||
|
|
||||||
class CPUWorker:
|
class CPUWorker(LoraNotSupportedWorkerBase):
|
||||||
"""A worker class that executes (a partition of) the model on a CPU socket.
|
"""A worker class that executes (a partition of) the model on a CPU socket.
|
||||||
|
|
||||||
Each worker is associated with a single CPU socket. The worker is
|
Each worker is associated with a single CPU socket. The worker is
|
||||||
@ -127,6 +128,7 @@ class CPUWorker:
|
|||||||
parallel_config: ParallelConfig,
|
parallel_config: ParallelConfig,
|
||||||
scheduler_config: SchedulerConfig,
|
scheduler_config: SchedulerConfig,
|
||||||
device_config: DeviceConfig,
|
device_config: DeviceConfig,
|
||||||
|
cache_config: CacheConfig,
|
||||||
local_rank: int,
|
local_rank: int,
|
||||||
rank: int,
|
rank: int,
|
||||||
distributed_init_method: str,
|
distributed_init_method: str,
|
||||||
@ -138,6 +140,7 @@ class CPUWorker:
|
|||||||
self.parallel_config = parallel_config
|
self.parallel_config = parallel_config
|
||||||
self.scheduler_config = scheduler_config
|
self.scheduler_config = scheduler_config
|
||||||
self.device_config = device_config
|
self.device_config = device_config
|
||||||
|
self.cache_config = cache_config
|
||||||
self.local_rank = local_rank
|
self.local_rank = local_rank
|
||||||
self.rank = rank
|
self.rank = rank
|
||||||
self.distributed_init_method = distributed_init_method
|
self.distributed_init_method = distributed_init_method
|
||||||
@ -154,8 +157,7 @@ class CPUWorker:
|
|||||||
kv_cache_dtype=kv_cache_dtype,
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
is_driver_worker=is_driver_worker)
|
is_driver_worker=is_driver_worker)
|
||||||
# Uninitialized cache engine. Will be initialized by
|
# Uninitialized cache engine. Will be initialized by
|
||||||
# self.init_cache_engine().
|
# initialize_cache.
|
||||||
self.cache_config = None
|
|
||||||
self.cache_engine = None
|
self.cache_engine = None
|
||||||
self.cpu_cache = None
|
self.cpu_cache = None
|
||||||
|
|
||||||
@ -167,28 +169,70 @@ class CPUWorker:
|
|||||||
def load_model(self):
|
def load_model(self):
|
||||||
self.model_runner.load_model()
|
self.model_runner.load_model()
|
||||||
|
|
||||||
def get_cpu_cache_block_num(
|
def determine_num_available_blocks(self) -> tuple[int, int]:
|
||||||
self,
|
"""Determine the number of blocks available for the KV cache.
|
||||||
block_size: int,
|
|
||||||
cache_space: int,
|
This determines how many KV blocks can fit into the configured CPU
|
||||||
cache_dtype: str,
|
KV cache space.
|
||||||
) -> int:
|
|
||||||
"""
|
Note that since vLLM assumes a block resides on GPU if it can be
|
||||||
Args:
|
modified, we return num_gpu_blocks=num_cpu_blocks and num_cpu_blocks=0.
|
||||||
block_size: The size of the cache block.
|
This allows us to reuse the scheduler of vLLM without generalizing it
|
||||||
cache_space: The size of the CPU KV cache space in bytes.
|
to different devices.
|
||||||
"""
|
"""
|
||||||
# For CPU device, the block number will be calculated based on the
|
# For CPU device, the block number will be calculated based on the
|
||||||
# cpu_kvcache_space.
|
# cpu_kvcache_space.
|
||||||
cache_block_size = CPUCacheEngine.get_cache_block_size(
|
cache_block_size = self.get_cache_block_size_bytes()
|
||||||
block_size, cache_dtype, self.model_config, self.parallel_config)
|
num_cpu_blocks = int(self.cache_config.cpu_kvcache_space_bytes //
|
||||||
num_cpu_blocks = int(cache_space // cache_block_size)
|
cache_block_size)
|
||||||
num_cpu_blocks = max(num_cpu_blocks, 0)
|
num_cpu_blocks = max(num_cpu_blocks, 0)
|
||||||
|
|
||||||
return num_cpu_blocks
|
# Note: To reuse the cache management procedure,
|
||||||
|
# use cpu cache as 'gpu cache'.
|
||||||
|
num_gpu_blocks = num_cpu_blocks
|
||||||
|
num_cpu_blocks = 0
|
||||||
|
return num_gpu_blocks, num_cpu_blocks
|
||||||
|
|
||||||
def init_cache_engine(self, cache_config: CacheConfig) -> None:
|
def initialize_cache(self, num_gpu_blocks: int,
|
||||||
self.cache_config = cache_config
|
num_cpu_blocks: int) -> None:
|
||||||
|
"""Initialize the KV cache. Currently, swappable CPU memory is not
|
||||||
|
supported.
|
||||||
|
|
||||||
|
Since this worker does not support GPUs, we use the num_gpu_blocks to
|
||||||
|
determine how many non-swappable CPU blocks to allocate.
|
||||||
|
"""
|
||||||
|
assert (num_cpu_blocks == 0
|
||||||
|
), f"{type(self)} does not support swappable cache"
|
||||||
|
|
||||||
|
# Note: To reuse the cache management procedure,
|
||||||
|
# use cpu cache as 'gpu cache'.
|
||||||
|
num_cpu_blocks = num_gpu_blocks
|
||||||
|
|
||||||
|
self._validate_num_cpu_blocks(num_cpu_blocks)
|
||||||
|
self.cache_config.num_gpu_blocks = num_cpu_blocks
|
||||||
|
self.cache_config.num_cpu_blocks = 0
|
||||||
|
|
||||||
|
# Initialize the cache.
|
||||||
|
self._init_cache_engine()
|
||||||
|
|
||||||
|
def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None:
|
||||||
|
"""Raise errors if the num_cpu_blocks is invalid.
|
||||||
|
"""
|
||||||
|
if num_cpu_blocks <= 0:
|
||||||
|
raise ValueError("No available memory for the cache blocks. "
|
||||||
|
"Try increasing `VLLM_CPU_KVCACHE_SPACE` when "
|
||||||
|
"initializing the engine.")
|
||||||
|
|
||||||
|
max_seq_len = self.cache_config.block_size * num_cpu_blocks
|
||||||
|
if self.model_config.max_model_len > max_seq_len:
|
||||||
|
raise ValueError(
|
||||||
|
f"The model's max seq len ({self.model_config.max_model_len}) "
|
||||||
|
"is larger than the maximum number of tokens that can be "
|
||||||
|
f"stored in KV cache ({max_seq_len}). Try increasing "
|
||||||
|
"`VLLM_CPU_KVCACHE_SPACE` or decreasing `max_model_len` when "
|
||||||
|
"initializing the engine.")
|
||||||
|
|
||||||
|
def _init_cache_engine(self) -> None:
|
||||||
self.cache_engine = CPUCacheEngine(self.cache_config,
|
self.cache_engine = CPUCacheEngine(self.cache_config,
|
||||||
self.model_config,
|
self.model_config,
|
||||||
self.parallel_config,
|
self.parallel_config,
|
||||||
@ -264,3 +308,10 @@ class CPUWorker:
|
|||||||
ensure_model_parallel_initialized(
|
ensure_model_parallel_initialized(
|
||||||
parallel_config.tensor_parallel_size,
|
parallel_config.tensor_parallel_size,
|
||||||
parallel_config.pipeline_parallel_size)
|
parallel_config.pipeline_parallel_size)
|
||||||
|
|
||||||
|
def get_cache_block_size_bytes(self) -> int:
|
||||||
|
"""Return the size in bytes of a single KV cache block.
|
||||||
|
"""
|
||||||
|
return CPUCacheEngine.get_cache_block_size(
|
||||||
|
self.cache_config.block_size, self.cache_config.cache_dtype,
|
||||||
|
self.model_config, self.parallel_config)
|
||||||
|
|||||||
@ -4,14 +4,15 @@ from typing import List, Optional
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
import torch.distributed
|
||||||
|
|
||||||
from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
|
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
|
||||||
SchedulerConfig)
|
ParallelConfig, SchedulerConfig)
|
||||||
from vllm.model_executor import set_random_seed
|
from vllm.model_executor import set_random_seed
|
||||||
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
||||||
from vllm.worker.neuron_model_runner import NeuronModelRunner
|
from vllm.worker.neuron_model_runner import NeuronModelRunner
|
||||||
|
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
|
||||||
|
|
||||||
|
|
||||||
class NeuronWorker:
|
class NeuronWorker(LoraNotSupportedWorkerBase):
|
||||||
"""A worker class that executes the model on a group of neuron cores.
|
"""A worker class that executes the model on a group of neuron cores.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -21,11 +22,13 @@ class NeuronWorker:
|
|||||||
parallel_config: ParallelConfig,
|
parallel_config: ParallelConfig,
|
||||||
scheduler_config: SchedulerConfig,
|
scheduler_config: SchedulerConfig,
|
||||||
device_config: DeviceConfig,
|
device_config: DeviceConfig,
|
||||||
|
cache_config: CacheConfig,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.model_config = model_config
|
self.model_config = model_config
|
||||||
self.parallel_config = parallel_config
|
self.parallel_config = parallel_config
|
||||||
self.scheduler_config = scheduler_config
|
self.scheduler_config = scheduler_config
|
||||||
self.device_config = device_config
|
self.device_config = device_config
|
||||||
|
self.cache_config = cache_config
|
||||||
|
|
||||||
self.model_runner = NeuronModelRunner(model_config, parallel_config,
|
self.model_runner = NeuronModelRunner(model_config, parallel_config,
|
||||||
scheduler_config, device_config)
|
scheduler_config, device_config)
|
||||||
@ -37,6 +40,35 @@ class NeuronWorker:
|
|||||||
def load_model(self):
|
def load_model(self):
|
||||||
self.model_runner.load_model()
|
self.model_runner.load_model()
|
||||||
|
|
||||||
|
def determine_num_available_blocks(self) -> tuple[int, int]:
|
||||||
|
"""Determine the number of available KV blocks.
|
||||||
|
|
||||||
|
Swapping is not yet supported, so always return num_cpu_blocks=0.
|
||||||
|
|
||||||
|
We configure num_gpu_blocks to be equal to max_num_seqs.
|
||||||
|
"""
|
||||||
|
# Set the number of GPU blocks to be the same as the maximum number of
|
||||||
|
# sequences that can be processed in a single batch. This is equivalent
|
||||||
|
# to schedule without PagedAttention.
|
||||||
|
num_gpu_blocks = self.scheduler_config.max_num_seqs
|
||||||
|
|
||||||
|
# Swap not yet supported with Neuron backend.
|
||||||
|
num_cpu_blocks = 0
|
||||||
|
|
||||||
|
return num_gpu_blocks, num_cpu_blocks
|
||||||
|
|
||||||
|
def initialize_cache(self, num_gpu_blocks: int,
|
||||||
|
num_cpu_blocks: int) -> None:
|
||||||
|
"""Initialize the KV cache.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Different values are not tested.
|
||||||
|
assert num_cpu_blocks == 0
|
||||||
|
assert num_gpu_blocks == self.scheduler_config.max_num_seqs
|
||||||
|
|
||||||
|
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||||
|
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def execute_model(
|
def execute_model(
|
||||||
self,
|
self,
|
||||||
@ -50,3 +82,10 @@ class NeuronWorker:
|
|||||||
|
|
||||||
output = self.model_runner.execute_model(seq_group_metadata_list)
|
output = self.model_runner.execute_model(seq_group_metadata_list)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
def get_cache_block_size_bytes(self) -> int:
|
||||||
|
"""Determine the size in bytes of a cache block.
|
||||||
|
|
||||||
|
This is required for speculative decoding; it is not yet implemented.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|||||||
@ -19,9 +19,10 @@ from vllm.model_executor.parallel_utils.parallel_state import (
|
|||||||
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
||||||
from vllm.worker.cache_engine import CacheEngine
|
from vllm.worker.cache_engine import CacheEngine
|
||||||
from vllm.worker.model_runner import ModelRunner
|
from vllm.worker.model_runner import ModelRunner
|
||||||
|
from vllm.worker.worker_base import WorkerBase
|
||||||
|
|
||||||
|
|
||||||
class Worker:
|
class Worker(WorkerBase):
|
||||||
"""A worker class that executes (a partition of) the model on a GPU.
|
"""A worker class that executes (a partition of) the model on a GPU.
|
||||||
|
|
||||||
Each worker is associated with a single GPU. The worker is responsible for
|
Each worker is associated with a single GPU. The worker is responsible for
|
||||||
@ -35,18 +36,19 @@ class Worker:
|
|||||||
parallel_config: ParallelConfig,
|
parallel_config: ParallelConfig,
|
||||||
scheduler_config: SchedulerConfig,
|
scheduler_config: SchedulerConfig,
|
||||||
device_config: DeviceConfig,
|
device_config: DeviceConfig,
|
||||||
|
cache_config: CacheConfig,
|
||||||
local_rank: int,
|
local_rank: int,
|
||||||
rank: int,
|
rank: int,
|
||||||
distributed_init_method: str,
|
distributed_init_method: str,
|
||||||
lora_config: Optional[LoRAConfig] = None,
|
lora_config: Optional[LoRAConfig] = None,
|
||||||
vision_language_config: Optional[VisionLanguageConfig] = None,
|
vision_language_config: Optional[VisionLanguageConfig] = None,
|
||||||
kv_cache_dtype: Optional[str] = "auto",
|
|
||||||
is_driver_worker: bool = False,
|
is_driver_worker: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.model_config = model_config
|
self.model_config = model_config
|
||||||
self.parallel_config = parallel_config
|
self.parallel_config = parallel_config
|
||||||
self.scheduler_config = scheduler_config
|
self.scheduler_config = scheduler_config
|
||||||
self.device_config = device_config
|
self.device_config = device_config
|
||||||
|
self.cache_config = cache_config
|
||||||
self.local_rank = local_rank
|
self.local_rank = local_rank
|
||||||
self.rank = rank
|
self.rank = rank
|
||||||
self.distributed_init_method = distributed_init_method
|
self.distributed_init_method = distributed_init_method
|
||||||
@ -66,12 +68,11 @@ class Worker:
|
|||||||
scheduler_config,
|
scheduler_config,
|
||||||
device_config,
|
device_config,
|
||||||
lora_config=self.lora_config,
|
lora_config=self.lora_config,
|
||||||
kv_cache_dtype=kv_cache_dtype,
|
kv_cache_dtype=self.cache_config.cache_dtype,
|
||||||
is_driver_worker=is_driver_worker,
|
is_driver_worker=is_driver_worker,
|
||||||
vision_language_config=vision_language_config)
|
vision_language_config=vision_language_config)
|
||||||
# Uninitialized cache engine. Will be initialized by
|
# Uninitialized cache engine. Will be initialized by
|
||||||
# self.init_cache_engine().
|
# initialize_cache.
|
||||||
self.cache_config = None
|
|
||||||
self.cache_engine = None
|
self.cache_engine = None
|
||||||
self.gpu_cache = None
|
self.gpu_cache = None
|
||||||
|
|
||||||
@ -107,20 +108,17 @@ class Worker:
|
|||||||
self.model_runner.load_model()
|
self.model_runner.load_model()
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def profile_num_available_blocks(
|
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||||
self,
|
"""Profiles the peak memory usage of the model to determine how many
|
||||||
block_size: int,
|
KV blocks may be allocated without OOMs.
|
||||||
gpu_memory_utilization: float,
|
|
||||||
cpu_swap_space: int,
|
|
||||||
cache_dtype: str,
|
|
||||||
) -> Tuple[int, int]:
|
|
||||||
"""Profiles the peak memory usage of the model and returns the maximum
|
|
||||||
number of GPU and CPU cache blocks that can be allocated.
|
|
||||||
|
|
||||||
Args:
|
The engine will first conduct a profiling of the existing memory usage.
|
||||||
block_size: The size of the cache block.
|
Then, it calculate the maximum possible number of GPU and CPU blocks
|
||||||
gpu_memory_utilization: The fraction of the total GPU memory to use.
|
that can be allocated with the remaining free memory.
|
||||||
cpu_swap_space: The size of the CPU swap space in bytes.
|
|
||||||
|
.. tip::
|
||||||
|
You may limit the usage of GPU memory
|
||||||
|
by adjusting the `gpu_memory_utilization` parameter.
|
||||||
"""
|
"""
|
||||||
# Profile the memory usage of the model and get the maximum number of
|
# Profile the memory usage of the model and get the maximum number of
|
||||||
# cache blocks that can be allocated with the remaining free memory.
|
# cache blocks that can be allocated with the remaining free memory.
|
||||||
@ -141,12 +139,12 @@ class Worker:
|
|||||||
"Error in memory profiling. This happens when the GPU memory was "
|
"Error in memory profiling. This happens when the GPU memory was "
|
||||||
"not properly cleaned up before initializing the vLLM instance.")
|
"not properly cleaned up before initializing the vLLM instance.")
|
||||||
|
|
||||||
cache_block_size = self.get_cache_block_size_bytes(
|
cache_block_size = self.get_cache_block_size_bytes()
|
||||||
block_size, cache_dtype)
|
|
||||||
num_gpu_blocks = int(
|
num_gpu_blocks = int(
|
||||||
(total_gpu_memory * gpu_memory_utilization - peak_memory) //
|
(total_gpu_memory * self.cache_config.gpu_memory_utilization -
|
||||||
cache_block_size)
|
peak_memory) // cache_block_size)
|
||||||
num_cpu_blocks = int(cpu_swap_space // cache_block_size)
|
num_cpu_blocks = int(self.cache_config.swap_space_bytes //
|
||||||
|
cache_block_size)
|
||||||
num_gpu_blocks = max(num_gpu_blocks, 0)
|
num_gpu_blocks = max(num_gpu_blocks, 0)
|
||||||
num_cpu_blocks = max(num_cpu_blocks, 0)
|
num_cpu_blocks = max(num_cpu_blocks, 0)
|
||||||
if self.model_runner.lora_manager:
|
if self.model_runner.lora_manager:
|
||||||
@ -155,14 +153,30 @@ class Worker:
|
|||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
return num_gpu_blocks, num_cpu_blocks
|
return num_gpu_blocks, num_cpu_blocks
|
||||||
|
|
||||||
def init_cache_engine(self, cache_config: CacheConfig) -> None:
|
def initialize_cache(self, num_gpu_blocks: int,
|
||||||
self.cache_config = cache_config
|
num_cpu_blocks: int) -> None:
|
||||||
|
"""Allocate GPU and CPU KV cache with the specified number of blocks.
|
||||||
|
|
||||||
|
This also warms up the model, which may record CUDA graphs.
|
||||||
|
"""
|
||||||
|
raise_if_cache_size_invalid(num_gpu_blocks,
|
||||||
|
self.cache_config.block_size,
|
||||||
|
self.model_config.max_model_len)
|
||||||
|
|
||||||
|
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||||
|
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||||
|
|
||||||
|
self._init_cache_engine()
|
||||||
|
self._warm_up_model()
|
||||||
|
|
||||||
|
def _init_cache_engine(self):
|
||||||
|
assert self.cache_config.num_gpu_blocks is not None
|
||||||
self.cache_engine = CacheEngine(self.cache_config, self.model_config,
|
self.cache_engine = CacheEngine(self.cache_config, self.model_config,
|
||||||
self.parallel_config)
|
self.parallel_config)
|
||||||
self.gpu_cache = self.cache_engine.gpu_cache
|
self.gpu_cache = self.cache_engine.gpu_cache
|
||||||
self.model_runner.set_block_size(self.cache_engine.block_size)
|
self.model_runner.set_block_size(self.cache_engine.block_size)
|
||||||
|
|
||||||
def warm_up_model(self) -> None:
|
def _warm_up_model(self) -> None:
|
||||||
if not self.model_config.enforce_eager:
|
if not self.model_config.enforce_eager:
|
||||||
self.model_runner.capture_model(self.gpu_cache)
|
self.model_runner.capture_model(self.gpu_cache)
|
||||||
# Reset the seed to ensure that the random state is not affected by
|
# Reset the seed to ensure that the random state is not affected by
|
||||||
@ -239,11 +253,10 @@ class Worker:
|
|||||||
def vocab_size(self) -> int:
|
def vocab_size(self) -> int:
|
||||||
return self.model_runner.vocab_size
|
return self.model_runner.vocab_size
|
||||||
|
|
||||||
def get_cache_block_size_bytes(self, block_size: int,
|
def get_cache_block_size_bytes(self) -> int:
|
||||||
cache_dtype: str) -> int:
|
|
||||||
"""Get the size of the KV cache block size in bytes.
|
"""Get the size of the KV cache block size in bytes.
|
||||||
"""
|
"""
|
||||||
return CacheEngine.get_cache_block_size(block_size, cache_dtype,
|
return CacheEngine.get_cache_block_size(self.cache_config,
|
||||||
self.model_config,
|
self.model_config,
|
||||||
self.parallel_config)
|
self.parallel_config)
|
||||||
|
|
||||||
@ -300,3 +313,19 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
|
|||||||
f"{compute_capability[0]}.{compute_capability[1]}. "
|
f"{compute_capability[0]}.{compute_capability[1]}. "
|
||||||
"You can use float16 instead by explicitly setting the"
|
"You can use float16 instead by explicitly setting the"
|
||||||
"`dtype` flag in CLI, for example: --dtype=half.")
|
"`dtype` flag in CLI, for example: --dtype=half.")
|
||||||
|
|
||||||
|
|
||||||
|
def raise_if_cache_size_invalid(num_gpu_blocks, block_size,
|
||||||
|
max_model_len) -> None:
|
||||||
|
if num_gpu_blocks <= 0:
|
||||||
|
raise ValueError("No available memory for the cache blocks. "
|
||||||
|
"Try increasing `gpu_memory_utilization` when "
|
||||||
|
"initializing the engine.")
|
||||||
|
max_seq_len = block_size * num_gpu_blocks
|
||||||
|
if max_model_len > max_seq_len:
|
||||||
|
raise ValueError(
|
||||||
|
f"The model's max seq len ({max_model_len}) "
|
||||||
|
"is larger than the maximum number of tokens that can be "
|
||||||
|
f"stored in KV cache ({max_seq_len}). Try increasing "
|
||||||
|
"`gpu_memory_utilization` or decreasing `max_model_len` when "
|
||||||
|
"initializing the engine.")
|
||||||
|
|||||||
83
vllm/worker/worker_base.py
Normal file
83
vllm/worker/worker_base.py
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from vllm.lora.request import LoRARequest
|
||||||
|
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
||||||
|
|
||||||
|
|
||||||
|
class WorkerBase(ABC):
|
||||||
|
"""Worker interface that allows vLLM to cleanly separate implementations for
|
||||||
|
different hardware.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def init_device(self) -> None:
|
||||||
|
"""Initialize device state, such as loading the model or other on-device
|
||||||
|
memory allocations.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def determine_num_available_blocks(self) -> tuple[int, int]:
|
||||||
|
"""Determine the number of available blocks for the GPU KV cache and
|
||||||
|
swappable CPU KV cache.
|
||||||
|
|
||||||
|
The implementation may run profiling or other heuristics to determine
|
||||||
|
the size of caches.
|
||||||
|
|
||||||
|
Returns a tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
|
||||||
|
are blocks that are "active" on the device and can be appended to.
|
||||||
|
num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
|
||||||
|
appended to.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def initialize_cache(self, num_gpu_blocks: int,
|
||||||
|
num_cpu_blocks: int) -> None:
|
||||||
|
"""Initialize the KV cache with the given size in blocks.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def execute_model(self,
|
||||||
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
|
blocks_to_swap_in: Dict[int, int],
|
||||||
|
blocks_to_swap_out: Dict[int, int],
|
||||||
|
blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
|
||||||
|
"""Executes one model step on the given sequences."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_cache_block_size_bytes() -> int:
|
||||||
|
"""Return the size of a single cache block, in bytes. Used in
|
||||||
|
speculative decoding.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def remove_lora(self, lora_id: int) -> bool:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def list_loras(self) -> List[int]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class LoraNotSupportedWorkerBase(WorkerBase):
|
||||||
|
"""Partial implementation of WorkerBase that raises exceptions when LoRA
|
||||||
|
methods are invoked.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||||
|
raise ValueError(f"{type(self)} does not support LoRA")
|
||||||
|
|
||||||
|
def remove_lora(self, lora_id: int) -> bool:
|
||||||
|
raise ValueError(f"{type(self)} does not support LoRA")
|
||||||
|
|
||||||
|
def list_loras(self) -> List[int]:
|
||||||
|
raise ValueError(f"{type(self)} does not support LoRA")
|
||||||
Loading…
x
Reference in New Issue
Block a user