Allow users to specify kv cache memory size (#21489)

Signed-off-by: Boyuan Feng <boyuan@meta.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Boyuan Feng 2025-09-11 06:41:07 -07:00 committed by GitHub
parent fd1ce98cdd
commit 94e6b2d55f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 236 additions and 47 deletions

View File

@ -113,6 +113,15 @@ class CacheConfig:
necessary for implementing this optimization in some models (e.g. Gemma3n)
"""
kv_cache_memory_bytes: Optional[int] = None
"""Size of KV Cache per GPU in bytes. By default, this is set to None
and vllm can automatically infer the kv cache size based on
gpu_memory_utilization. However, users may want to manually specify
the kv cache memory size. kv_cache_memory_bytes allows more fine-grain
control of how much memory gets used when compared with using
gpu_memory_memory_utilization. Note that kv_cache_memory_bytes
(when not-None) ignores gpu_memory_utilization"""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,

View File

@ -227,8 +227,14 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
elif contains_type(type_hints, int):
kwargs[name]["type"] = int
# Special case for large integers
if name in {"max_model_len", "max_num_batched_tokens"}:
human_readable_ints = {
"max_model_len",
"max_num_batched_tokens",
"kv_cache_memory_bytes",
}
if name in human_readable_ints:
kwargs[name]["type"] = human_readable_int
kwargs[name]["help"] += f"\n\n{human_readable_int.__doc__}"
elif contains_type(type_hints, float):
kwargs[name]["type"] = float
elif (contains_type(type_hints, dict)
@ -335,6 +341,7 @@ class EngineArgs:
swap_space: float = CacheConfig.swap_space
cpu_offload_gb: float = CacheConfig.cpu_offload_gb
gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
kv_cache_memory_bytes: Optional[int] = CacheConfig.kv_cache_memory_bytes
max_num_batched_tokens: Optional[
int] = SchedulerConfig.max_num_batched_tokens
max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
@ -734,6 +741,8 @@ class EngineArgs:
cache_group.add_argument("--block-size", **cache_kwargs["block_size"])
cache_group.add_argument("--gpu-memory-utilization",
**cache_kwargs["gpu_memory_utilization"])
cache_group.add_argument("--kv-cache-memory-bytes",
**cache_kwargs["kv_cache_memory_bytes"])
cache_group.add_argument("--swap-space", **cache_kwargs["swap_space"])
cache_group.add_argument("--kv-cache-dtype",
**cache_kwargs["cache_dtype"])
@ -1174,6 +1183,7 @@ class EngineArgs:
cache_config = CacheConfig(
block_size=self.block_size,
gpu_memory_utilization=self.gpu_memory_utilization,
kv_cache_memory_bytes=self.kv_cache_memory_bytes,
swap_space=self.swap_space,
cache_dtype=self.kv_cache_dtype,
is_attention_free=model_config.is_attention_free,

View File

@ -278,7 +278,8 @@ class LLMEngine:
self.cache_config.block_size,
"gpu_memory_utilization":
self.cache_config.gpu_memory_utilization,
"kv_cache_memory_bytes":
self.cache_config.kv_cache_memory_bytes,
# Quantization
"quantization":
self.model_config.quantization,

View File

@ -110,6 +110,14 @@ class LLM:
values will increase the KV cache size and thus improve the model's
throughput. However, if the value is too high, it may cause out-of-
memory (OOM) errors.
kv_cache_memory_bytes: Size of KV Cache per GPU in bytes. By default,
this is set to None and vllm can automatically infer the kv cache
size based on gpu_memory_utilization. However, users may want to
manually specify the kv cache memory size. kv_cache_memory_bytes
allows more fine-grain control of how much memory gets used when
compared with using gpu_memory_memory_utilization. Note that
kv_cache_memory_bytes (when not-None) ignores
gpu_memory_utilization
swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
This can be used for temporarily storing the states of the requests
when their `best_of` sampling parameters are larger than 1. If all
@ -184,6 +192,7 @@ class LLM:
hf_overrides: Optional[HfOverrides] = None,
mm_processor_kwargs: Optional[dict[str, Any]] = None,
override_pooler_config: Optional[PoolerConfig] = None,
kv_cache_memory_bytes: Optional[int] = None,
compilation_config: Optional[Union[int, dict[str, Any],
CompilationConfig]] = None,
logits_processors: Optional[list[Union[str,
@ -251,6 +260,7 @@ class LLM:
tokenizer_revision=tokenizer_revision,
seed=seed,
gpu_memory_utilization=gpu_memory_utilization,
kv_cache_memory_bytes=kv_cache_memory_bytes,
swap_space=swap_space,
cpu_offload_gb=cpu_offload_gb,
enforce_eager=enforce_eager,

View File

@ -2791,7 +2791,10 @@ def memory_profiling(
result.torch_peak_increase = diff_profile.torch_peak
result.non_torch_increase = diff_from_create.non_torch_memory
result.profile_time = diff_profile.timestamp
result.non_kv_cache_memory = result.non_torch_increase + result.torch_peak_increase + result.weights_memory # noqa
non_torch_memory = result.non_torch_increase
peak_activation_memory = result.torch_peak_increase
result.non_kv_cache_memory = non_torch_memory + peak_activation_memory + result.weights_memory # noqa
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630 # noqa: E501

View File

@ -355,7 +355,8 @@ def report_usage_stats(
vllm_config.cache_config.block_size,
"gpu_memory_utilization":
vllm_config.cache_config.gpu_memory_utilization,
"kv_cache_memory_bytes":
vllm_config.cache_config.kv_cache_memory_bytes,
# Quantization
"quantization":
vllm_config.model_config.quantization,

View File

@ -3041,12 +3041,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.encoder_cache.clear()
gc.collect()
def capture_model(self) -> None:
def capture_model(self) -> int:
if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
logger.warning(
"Skipping CUDA graph capture. To turn on CUDA graph capture, "
"ensure `cudagraph_mode` was not manually set to `NONE`")
return
return 0
else:
self.initialize_cudagraph_capture()
@ -3117,6 +3117,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# This usually takes 5~20 seconds.
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
elapsed_time, cuda_graph_size / (1 << 30))
return cuda_graph_size
def _capture_cudagraphs(self, compilation_cases: list[int],
cudagraph_runtime_mode: CUDAGraphMode,

View File

@ -231,18 +231,40 @@ class Worker(WorkerBase):
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
GiB = lambda b: b / GiB_bytes
if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes:
# still need a profile run which compiles the model for
# max_num_batched_tokens
self.model_runner.profile_run()
msg = (
f"Initial free memory {GiB(self.init_snapshot.free_memory)} "
f"GiB, reserved {GiB(kv_cache_memory_bytes):.2f}GiB memory for "
"KV Cache as specified by kv_cache_memory_bytes config and "
"skipped memory profiling. This does does not respect the "
"gpu_memory_utilization config. Only use kv_cache_memory_bytes "
"config when you want manual control of KV cache memory "
"size. If OOM'ed, check the difference of initial free "
"memory between the current run and the previous run "
"where kv_cache_memory_bytes is suggested and update it "
"correspondingly.")
logger.info(msg)
return kv_cache_memory_bytes
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
GiB = lambda b: b / GiB_bytes
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
with memory_profiling(
self.init_snapshot,
weights_memory=int(
self.model_runner.model_memory_usage)) as profile_result:
weights_memory=int(self.model_runner.model_memory_usage),
) as profile_result:
self.model_runner.profile_run()
self.non_torch_memory = profile_result.non_torch_increase
self.peak_activation_memory = profile_result.torch_peak_increase
free_gpu_memory = profile_result.after_profile.free_memory
# NOTE(woosuk): Here we assume that the other processes using the same
# GPU did not change their memory usage during the profiling.
@ -254,7 +276,7 @@ class Worker(WorkerBase):
"release GPU memory while vLLM is profiling during initialization. "
"To fix this, ensure consistent GPU memory allocation or "
"isolate vLLM in its own container.")
available_kv_cache_memory = self.requested_memory \
self.available_kv_cache_memory_bytes = self.requested_memory \
- profile_result.non_kv_cache_memory
unrequested_memory = self.init_snapshot.free_memory \
@ -274,10 +296,10 @@ class Worker(WorkerBase):
)
logger.debug(profile_result)
logger.info("Available KV cache memory: %.2f GiB",
GiB(available_kv_cache_memory))
GiB(self.available_kv_cache_memory_bytes))
gc.collect()
return int(available_kv_cache_memory)
return int(self.available_kv_cache_memory_bytes)
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
return self.model_runner.get_kv_cache_spec()
@ -317,8 +339,56 @@ class Worker(WorkerBase):
# cuda graph capture.
kernel_warmup(self)
cuda_graph_memory_bytes = 0
if not self.model_config.enforce_eager:
self.model_runner.capture_model()
cuda_graph_memory_bytes = self.model_runner.capture_model()
if (self.cache_config.kv_cache_memory_bytes is None
and hasattr(self, "peak_activation_memory")):
# Suggests optimal kv cache memory size if we rely on
# memory_profiling to guess the kv cache memory size which
# provides peak_activation_memory and a few other memory
# consumption. `memory_profiling` does not consider
# CUDAGraph memory size and may not utilize all gpu memory.
# Users may want fine-grained control to specify kv cache
# memory size.
GiB = lambda b: round(b / GiB_bytes, 2)
# empirically observed that the memory profiling may
# slightly underestimate the memory consumption.
# So leave a small buffer (=150MiB) to avoid OOM.
redundancy_buffer_memory = 150 * (1 << 20)
non_kv_cache_memory = (self.model_runner.model_memory_usage +
self.peak_activation_memory +
self.non_torch_memory +
cuda_graph_memory_bytes)
kv_cache_memory_bytes_to_gpu_limit = (
self.init_snapshot.free_memory - non_kv_cache_memory -
redundancy_buffer_memory)
kv_cache_memory_bytes_to_requested_limit = (
int(self.requested_memory) - non_kv_cache_memory -
redundancy_buffer_memory)
msg = (
f"Free memory on device "
f"({GiB(self.init_snapshot.free_memory)}/"
f"{GiB(self.init_snapshot.total_memory)} GiB) on startup. "
f"Desired GPU memory utilization is "
f"({self.cache_config.gpu_memory_utilization}, "
f"{GiB(self.requested_memory)} GiB). "
f"Actual usage is {GiB(self.model_runner.model_memory_usage)} "
f"GiB for weight, {GiB(self.peak_activation_memory)} GiB "
f"for peak activation, {GiB(self.non_torch_memory)} GiB "
f"for non-torch memory, and {GiB(cuda_graph_memory_bytes)} "
f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
f"config with `--kv-cache-memory="
f"{kv_cache_memory_bytes_to_requested_limit}` to fit into "
f"requested memory, or `--kv-cache-memory="
f"{kv_cache_memory_bytes_to_gpu_limit}` to fully "
f"utilize gpu memory. Current kv cache memory in use is "
f"{int(self.available_kv_cache_memory_bytes)} bytes.")
logger.info(msg)
# Warm up sampler and preallocate memory buffer for logits and other
# sampling related tensors of max possible shape to avoid memory

View File

@ -1337,8 +1337,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
return self.lora_manager.list_adapters()
@torch.inference_mode()
def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
"""Cuda graph capture a model.
def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> int:
"""Cuda graph capture a model and return cudagraph memory
consumption in bytes.
Note that CUDA graph's performance gain is negligible if number
of batched tokens are larger than 200. And since CUDA graph
@ -1505,6 +1506,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
# This usually takes < 10 seconds.
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
elapsed_time, cuda_graph_size / GiB_bytes)
return cuda_graph_size
def _update_inputs_to_capture_for_enc_dec_model(self,
capture_inputs: Dict[str,

View File

@ -229,6 +229,67 @@ class Worker(LocalOrDistributedWorkerBase):
self.model_runner.save_tensorized_model(
tensorizer_config=tensorizer_config, )
@torch.inference_mode()
def determine_available_kv_cache_memory(self,
total_gpu_memory: int) -> float:
if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes:
# still need a profile run which compiles the model for
# max_num_batched_tokens
self.model_runner.profile_run()
GiB = lambda b: b / GiB_bytes
msg = (
f"Initial free memory "
f"{GiB(self.baseline_snapshot.free_memory):.2f} "
f"GiB, reserved {GiB(kv_cache_memory_bytes):.2f}GiB memory for "
"KV Cache as specified by kv_cache_memory_bytes config and "
"skipped memory profiling. This does does not respect the "
"gpu_memory_utilization config. Only use kv_cache_memory_bytes "
"config when you want manual control of KV cache memory "
"size. If OOM'ed, check the difference of initial free "
"memory between the current run and the previous run "
"where kv_cache_memory_bytes is suggested and update it "
"correspondingly.")
logger.info(msg)
return self.cache_config.kv_cache_memory_bytes
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
with memory_profiling(
self.baseline_snapshot,
weights_memory=self.model_runner.model_memory_usage) as result:
self.model_runner.profile_run()
self.non_torch_memory = result.non_torch_increase
self.peak_activation_memory = result.torch_peak_increase
self._assert_memory_footprint_increased_during_profiling()
self.requested_memory = total_gpu_memory * \
self.cache_config.gpu_memory_utilization
self.available_kv_cache_memory = (self.requested_memory -
result.non_kv_cache_memory)
msg = (f"Memory profiling takes {result.profile_time:.2f} seconds\n"
"the current vLLM instance can use "
"total_gpu_memory "
f"({(total_gpu_memory / GiB_bytes):.2f}GiB)"
" x gpu_memory_utilization "
f"({self.cache_config.gpu_memory_utilization:.2f})"
f" = {(self.requested_memory / GiB_bytes):.2f}GiB\n"
"model weights take "
f"{(result.weights_memory / GiB_bytes):.2f}GiB;"
" non_torch_memory takes "
f"{(result.non_torch_increase / GiB_bytes):.2f}GiB;"
" PyTorch activation peak memory takes "
f"{(result.torch_peak_increase / GiB_bytes):.2f}GiB;"
" the rest of the memory reserved for KV Cache is "
f"{(self.available_kv_cache_memory / GiB_bytes):.2f}GiB.")
logger.info(msg)
return self.available_kv_cache_memory
@torch.inference_mode()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Profiles the peak memory usage of the model to determine how many
@ -248,20 +309,8 @@ class Worker(LocalOrDistributedWorkerBase):
torch.cuda.reset_peak_memory_stats()
free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info()
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
with memory_profiling(
self.baseline_snapshot,
weights_memory=self.model_runner.model_memory_usage) as result:
self.model_runner.profile_run()
self._assert_memory_footprint_increased_during_profiling()
memory_for_current_instance = total_gpu_memory * \
self.cache_config.gpu_memory_utilization
available_kv_cache_memory = (memory_for_current_instance -
result.non_kv_cache_memory)
available_kv_cache_memory = self.determine_available_kv_cache_memory(
total_gpu_memory)
# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
@ -276,23 +325,6 @@ class Worker(LocalOrDistributedWorkerBase):
num_gpu_blocks = max(num_gpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0)
msg = (f"Memory profiling takes {result.profile_time:.2f} seconds\n"
"the current vLLM instance can use "
"total_gpu_memory "
f"({(total_gpu_memory / GiB_bytes):.2f}GiB)"
" x gpu_memory_utilization "
f"({self.cache_config.gpu_memory_utilization:.2f})"
f" = {(memory_for_current_instance / GiB_bytes):.2f}GiB\n"
"model weights take "
f"{(result.weights_memory / GiB_bytes):.2f}GiB;"
" non_torch_memory takes "
f"{(result.non_torch_increase / GiB_bytes):.2f}GiB;"
" PyTorch activation peak memory takes "
f"{(result.torch_peak_increase / GiB_bytes):.2f}GiB;"
" the rest of the memory reserved for KV Cache is "
f"{(available_kv_cache_memory / GiB_bytes):.2f}GiB.")
logger.info(msg)
# Final cleanup
gc.collect()
@ -382,8 +414,58 @@ class Worker(LocalOrDistributedWorkerBase):
for size in sorted(warmup_sizes, reverse=True):
logger.info("Compile and warming up model for size %d", size)
self.model_runner._dummy_run(size)
cuda_graph_memory_bytes = 0
if not self.model_config.enforce_eager:
self.model_runner.capture_model(self.gpu_cache)
cuda_graph_memory_bytes = self.model_runner.capture_model(
self.gpu_cache)
if (self.cache_config.kv_cache_memory_bytes is None
and hasattr(self, "peak_activation_memory")):
# Suggests optimal kv cache memory size if we rely on
# memory_profiling to guess the kv cache memory size which
# provides peak_activation_memory and a few other memory
# consumption. `memory_profiling` does not consider
# CUDAGraph memory size and may not utilize all gpu memory.
# Users may want fine-grained control to specify kv cache
# memory size.
GiB = lambda b: round(b / GiB_bytes, 2)
non_kv_cache_memory = (self.model_runner.model_memory_usage +
self.peak_activation_memory +
self.non_torch_memory +
cuda_graph_memory_bytes)
# empirically observed that the memory profiling may
# slightly underestimate the memory consumption.
# So leave a small buffer (=150MiB) to avoid OOM.
redundancy_buffer_memory = 150 * (1 << 20)
kv_cache_memory_bytes_to_gpu_limit = (
self.baseline_snapshot.free_memory - non_kv_cache_memory -
redundancy_buffer_memory)
kv_cache_memory_bytes_to_requested_limit = (
int(self.requested_memory) - non_kv_cache_memory -
redundancy_buffer_memory)
msg = (
f"Free memory on device "
f"({GiB(self.baseline_snapshot.free_memory)}/"
f"{GiB(self.baseline_snapshot.total_memory)} GiB) on startup. "
f"Desired GPU memory utilization is "
f"({self.cache_config.gpu_memory_utilization}, "
f"{GiB(self.requested_memory)} GiB). "
f"Actual usage is {GiB(self.model_runner.model_memory_usage)} "
f"GiB for weight, {GiB(self.peak_activation_memory)} GiB "
f"for peak activation, {GiB(self.non_torch_memory)} GiB "
f"for non-torch memory, and {GiB(cuda_graph_memory_bytes)} "
f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
f"config with `--kv-cache-memory="
f"{kv_cache_memory_bytes_to_requested_limit}` to fit into "
f"requested memory, or `--kv-cache-memory="
f"{kv_cache_memory_bytes_to_gpu_limit}` to fully "
f"utilize gpu memory. Current kv cache memory in use is "
f"{int(self.available_kv_cache_memory)} bytes.")
logger.info(msg)
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)