mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 09:01:40 +08:00
Optimize KV cache distribution for asymmetric pipeline parallelism (#25164)
Signed-off-by: gholmes829 <g.holmes429@gmail.com>
This commit is contained in:
parent
7e4cd070b0
commit
d100d78eb3
@ -681,10 +681,10 @@ def test_get_kv_cache_configs_multiple_workers():
|
||||
num_blocks=10,
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(
|
||||
size=ref_kv_cache_spec.page_size_bytes * 20, shared_by=["layer1"]
|
||||
size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"]
|
||||
),
|
||||
KVCacheTensor(
|
||||
size=ref_kv_cache_spec.page_size_bytes * 20, shared_by=["layer2"]
|
||||
size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"]
|
||||
),
|
||||
],
|
||||
kv_cache_groups=[
|
||||
@ -718,7 +718,7 @@ def test_get_kv_cache_configs_multiple_workers():
|
||||
num_blocks=10,
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(
|
||||
size=ref_kv_cache_spec.page_size_bytes * 20, shared_by=["layer1"]
|
||||
size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"]
|
||||
),
|
||||
],
|
||||
kv_cache_groups=[
|
||||
@ -802,7 +802,7 @@ def test_get_kv_cache_configs_multiple_workers():
|
||||
num_blocks=10,
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(
|
||||
size=ref_kv_cache_spec.page_size_bytes * 20, shared_by=["layer3"]
|
||||
size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer3"]
|
||||
),
|
||||
],
|
||||
kv_cache_groups=[
|
||||
@ -813,7 +813,7 @@ def test_get_kv_cache_configs_multiple_workers():
|
||||
num_blocks=10,
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(
|
||||
size=ref_kv_cache_spec.page_size_bytes * 20, shared_by=["layer3"]
|
||||
size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer3"]
|
||||
),
|
||||
],
|
||||
kv_cache_groups=[
|
||||
|
||||
@ -124,7 +124,7 @@ class CacheConfig:
|
||||
gpu_memory_utilization. However, users may want to manually specify
|
||||
the kv cache memory size. kv_cache_memory_bytes allows more fine-grain
|
||||
control of how much memory gets used when compared with using
|
||||
gpu_memory_memory_utilization. Note that kv_cache_memory_bytes
|
||||
gpu_memory_utilization. Note that kv_cache_memory_bytes
|
||||
(when not-None) ignores gpu_memory_utilization"""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
|
||||
@ -143,7 +143,7 @@ class LLM:
|
||||
size based on gpu_memory_utilization. However, users may want to
|
||||
manually specify the kv cache memory size. kv_cache_memory_bytes
|
||||
allows more fine-grain control of how much memory gets used when
|
||||
compared with using gpu_memory_memory_utilization. Note that
|
||||
compared with using gpu_memory_utilization. Note that
|
||||
kv_cache_memory_bytes (when not-None) ignores
|
||||
gpu_memory_utilization
|
||||
swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
|
||||
|
||||
@ -1113,35 +1113,12 @@ def get_kv_cache_config_from_groups(
|
||||
KVCacheTensor(size=page_size * num_blocks, shared_by=shared_by)
|
||||
)
|
||||
|
||||
kv_cache_config = KVCacheConfig(
|
||||
return KVCacheConfig(
|
||||
num_blocks=num_blocks,
|
||||
kv_cache_tensors=kv_cache_tensors,
|
||||
kv_cache_groups=kv_cache_groups,
|
||||
)
|
||||
|
||||
min_block_size = min([group.kv_cache_spec.block_size for group in kv_cache_groups])
|
||||
|
||||
# Print the KV cache size and maximum concurrency.
|
||||
num_tokens = num_blocks // len(kv_cache_groups) * min_block_size
|
||||
if vllm_config.parallel_config.decode_context_parallel_size > 1:
|
||||
num_tokens *= vllm_config.parallel_config.decode_context_parallel_size
|
||||
logger.info(
|
||||
"Multiplying the GPU KV cache size by the dcp_world_size %d.",
|
||||
vllm_config.parallel_config.decode_context_parallel_size,
|
||||
)
|
||||
num_tokens_str = f"{num_tokens:,}"
|
||||
logger.info("GPU KV cache size: %s tokens", num_tokens_str)
|
||||
max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
|
||||
max_concurrency = get_max_concurrency_for_kv_cache_config(
|
||||
vllm_config, kv_cache_config
|
||||
)
|
||||
logger.info(
|
||||
"Maximum concurrency for %s tokens per request: %.2fx",
|
||||
max_model_len_str,
|
||||
max_concurrency,
|
||||
)
|
||||
return kv_cache_config
|
||||
|
||||
|
||||
def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
|
||||
"""
|
||||
@ -1265,6 +1242,45 @@ def generate_scheduler_kv_cache_config(
|
||||
return cfg
|
||||
|
||||
|
||||
def _report_kv_cache_config(
|
||||
vllm_config: VllmConfig, kv_cache_config: KVCacheConfig
|
||||
) -> None:
|
||||
"""
|
||||
Log resolved KV cache configuration.
|
||||
|
||||
Args:
|
||||
vllm_config: The global VllmConfig
|
||||
kv_cache_config: The resolved KV cache configuration
|
||||
"""
|
||||
min_block_size = min(
|
||||
[group.kv_cache_spec.block_size for group in kv_cache_config.kv_cache_groups]
|
||||
)
|
||||
|
||||
# Log the KV cache size and maximum concurrency.
|
||||
num_tokens = (
|
||||
kv_cache_config.num_blocks
|
||||
// len(kv_cache_config.kv_cache_groups)
|
||||
* min_block_size
|
||||
)
|
||||
if vllm_config.parallel_config.decode_context_parallel_size > 1:
|
||||
num_tokens *= vllm_config.parallel_config.decode_context_parallel_size
|
||||
logger.info(
|
||||
"Multiplying the GPU KV cache size by the dcp_world_size %d.",
|
||||
vllm_config.parallel_config.decode_context_parallel_size,
|
||||
)
|
||||
num_tokens_str = f"{num_tokens:,}"
|
||||
logger.info("GPU KV cache size: %s tokens", num_tokens_str)
|
||||
max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
|
||||
max_concurrency = get_max_concurrency_for_kv_cache_config(
|
||||
vllm_config, kv_cache_config
|
||||
)
|
||||
logger.info(
|
||||
"Maximum concurrency for %s tokens per request: %.2fx",
|
||||
max_model_len_str,
|
||||
max_concurrency,
|
||||
)
|
||||
|
||||
|
||||
def get_kv_cache_configs(
|
||||
vllm_config: VllmConfig,
|
||||
kv_cache_specs: list[dict[str, KVCacheSpec]],
|
||||
@ -1284,7 +1300,8 @@ def get_kv_cache_configs(
|
||||
3. Generate the KV cache configs for each worker based on the KV cache
|
||||
grouping strategy. (This is reasonable because the layer ratio of
|
||||
different PP stages are similar.)
|
||||
4. Change the num_blocks of each worker to the smallest among all workers.
|
||||
4. Change the num_blocks of each worker to the smallest among all workers
|
||||
and shrink tensor sizes proportionally to avoid allocating unused memory.
|
||||
|
||||
Args:
|
||||
vllm_config: The global VllmConfig
|
||||
@ -1345,13 +1362,22 @@ def get_kv_cache_configs(
|
||||
)
|
||||
)
|
||||
|
||||
# Change the num_blocks of each rank to the smallest among all ranks. We
|
||||
# do not need to shrink the tensor size because it is valid to only use the
|
||||
# first `num_blocks` blocks of the tensor.
|
||||
# Change the num_blocks of each rank to the smallest among all ranks.
|
||||
# We also need to shrink the tensor size proportionally to avoid
|
||||
# allocating unused memory.
|
||||
min_num_blocks = min(
|
||||
kv_cache_config.num_blocks for kv_cache_config in kv_cache_configs
|
||||
)
|
||||
for kv_cache_config in kv_cache_configs:
|
||||
num_blocks_old = kv_cache_config.num_blocks
|
||||
kv_cache_config.num_blocks = min_num_blocks
|
||||
|
||||
# Shrink tensor size proportionally
|
||||
for tensor in kv_cache_config.kv_cache_tensors:
|
||||
assert tensor.size % num_blocks_old == 0
|
||||
tensor.size = tensor.size // num_blocks_old * min_num_blocks
|
||||
|
||||
if len(kv_cache_config.kv_cache_groups) > 0:
|
||||
_report_kv_cache_config(vllm_config, kv_cache_config)
|
||||
|
||||
return kv_cache_configs
|
||||
|
||||
@ -253,10 +253,10 @@ class Worker(WorkerBase):
|
||||
self.model_runner.profile_run()
|
||||
|
||||
msg = (
|
||||
f"Initial free memory {GiB(self.init_snapshot.free_memory)} "
|
||||
f"GiB, reserved {GiB(kv_cache_memory_bytes):.2f}GiB memory for "
|
||||
f"Initial free memory {GiB(self.init_snapshot.free_memory):.2f} "
|
||||
f"GiB, reserved {GiB(kv_cache_memory_bytes):.2f} GiB memory for "
|
||||
"KV Cache as specified by kv_cache_memory_bytes config and "
|
||||
"skipped memory profiling. This does does not respect the "
|
||||
"skipped memory profiling. This does not respect the "
|
||||
"gpu_memory_utilization config. Only use kv_cache_memory_bytes "
|
||||
"config when you want manual control of KV cache memory "
|
||||
"size. If OOM'ed, check the difference of initial free "
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user