Allow users to specify kv cache memory size (#21489)

Signed-off-by: Boyuan Feng <boyuan@meta.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Boyuan Feng 2025-09-11 06:41:07 -07:00 committed by GitHub
parent fd1ce98cdd
commit 94e6b2d55f
10 changed files with 236 additions and 47 deletions

View File

@@ -113,6 +113,15 @@ class CacheConfig:
     necessary for implementing this optimization in some models (e.g. Gemma3n)
     """
 
+    kv_cache_memory_bytes: Optional[int] = None
+    """Size of the KV cache per GPU in bytes. By default this is None, and
+    vLLM automatically infers the KV cache size from gpu_memory_utilization.
+    However, users may want to specify the KV cache memory size manually:
+    kv_cache_memory_bytes gives finer-grained control over how much memory
+    is used than gpu_memory_utilization does. Note that when
+    kv_cache_memory_bytes is not None, gpu_memory_utilization is ignored.
+    """
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
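For intuition, a brief sketch of what the two sizing modes mean in practice (the GPU size and byte counts below are illustrative, not taken from this change):

```python
GiB = 1 << 30

# Automatic mode: the KV cache gets whatever is left of the requested budget.
total_gpu_memory = 80 * GiB        # e.g. an 80 GiB accelerator
gpu_memory_utilization = 0.9       # vLLM may use up to 72 GiB
non_kv_cache_memory = 22 * GiB     # weights + activations + non-torch overhead
kv_cache_from_utilization = (int(total_gpu_memory * gpu_memory_utilization)
                             - non_kv_cache_memory)   # 50 GiB

# Manual mode: the KV cache size is pinned and gpu_memory_utilization is ignored.
kv_cache_memory_bytes = 40 * GiB   # exactly 40 GiB, regardless of utilization
```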

View File

@@ -227,8 +227,14 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
         elif contains_type(type_hints, int):
             kwargs[name]["type"] = int
             # Special case for large integers
-            if name in {"max_model_len", "max_num_batched_tokens"}:
+            human_readable_ints = {
+                "max_model_len",
+                "max_num_batched_tokens",
+                "kv_cache_memory_bytes",
+            }
+            if name in human_readable_ints:
                 kwargs[name]["type"] = human_readable_int
+                kwargs[name]["help"] += f"\n\n{human_readable_int.__doc__}"
         elif contains_type(type_hints, float):
             kwargs[name]["type"] = float
         elif (contains_type(type_hints, dict)
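Routing the field through `human_readable_int` means the CLI flag presumably accepts human-readable suffixes in addition to plain byte counts, and the helper's docstring is appended to the `--help` text. As a hedged illustration only (this is not vLLM's actual parser, and its suffix rules may differ):

```python
def parse_human_readable_int(value: str) -> int:
    """Toy parser: lowercase suffixes are decimal, uppercase are binary."""
    multipliers = {"k": 10**3, "m": 10**6, "g": 10**9, "t": 10**12,
                   "K": 2**10, "M": 2**20, "G": 2**30, "T": 2**40}
    if value and value[-1] in multipliers:
        return int(float(value[:-1]) * multipliers[value[-1]])
    return int(value)

assert parse_human_readable_int("8589934592") == 8 * 2**30
assert parse_human_readable_int("8G") == 8 * 2**30
```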
@@ -335,6 +341,7 @@ class EngineArgs:
     swap_space: float = CacheConfig.swap_space
     cpu_offload_gb: float = CacheConfig.cpu_offload_gb
     gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
+    kv_cache_memory_bytes: Optional[int] = CacheConfig.kv_cache_memory_bytes
     max_num_batched_tokens: Optional[
         int] = SchedulerConfig.max_num_batched_tokens
     max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
@@ -734,6 +741,8 @@ class EngineArgs:
         cache_group.add_argument("--block-size", **cache_kwargs["block_size"])
         cache_group.add_argument("--gpu-memory-utilization",
                                  **cache_kwargs["gpu_memory_utilization"])
+        cache_group.add_argument("--kv-cache-memory-bytes",
+                                 **cache_kwargs["kv_cache_memory_bytes"])
         cache_group.add_argument("--swap-space", **cache_kwargs["swap_space"])
         cache_group.add_argument("--kv-cache-dtype",
                                  **cache_kwargs["cache_dtype"])
@@ -1174,6 +1183,7 @@ class EngineArgs:
         cache_config = CacheConfig(
             block_size=self.block_size,
             gpu_memory_utilization=self.gpu_memory_utilization,
+            kv_cache_memory_bytes=self.kv_cache_memory_bytes,
             swap_space=self.swap_space,
             cache_dtype=self.kv_cache_dtype,
             is_attention_free=model_config.is_attention_free,
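Putting the hunks above together, a hedged sketch of the programmatic path (the model name is a placeholder):

```python
from vllm.engine.arg_utils import EngineArgs

# Programmatic equivalent of passing `--kv-cache-memory-bytes 8589934592` on
# the CLI; per the hunk above, the value is forwarded into
# CacheConfig.kv_cache_memory_bytes when the engine config is built.
engine_args = EngineArgs(
    model="facebook/opt-125m",            # placeholder model
    kv_cache_memory_bytes=8 * (1 << 30),  # reserve exactly 8 GiB for the KV cache
)
```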

View File

@@ -278,7 +278,8 @@ class LLMEngine:
             self.cache_config.block_size,
             "gpu_memory_utilization":
             self.cache_config.gpu_memory_utilization,
+            "kv_cache_memory_bytes":
+            self.cache_config.kv_cache_memory_bytes,
             # Quantization
             "quantization":
             self.model_config.quantization,

View File

@@ -110,6 +110,14 @@ class LLM:
             values will increase the KV cache size and thus improve the model's
             throughput. However, if the value is too high, it may cause out-of-
             memory (OOM) errors.
+        kv_cache_memory_bytes: Size of the KV cache per GPU in bytes. By
+            default this is None, and vLLM automatically infers the KV cache
+            size from gpu_memory_utilization. However, users may want to
+            specify the KV cache memory size manually: kv_cache_memory_bytes
+            gives finer-grained control over how much memory is used than
+            gpu_memory_utilization does. Note that when
+            kv_cache_memory_bytes is not None, gpu_memory_utilization is
+            ignored.
         swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
             This can be used for temporarily storing the states of the requests
             when their `best_of` sampling parameters are larger than 1. If all
@@ -184,6 +192,7 @@ class LLM:
         hf_overrides: Optional[HfOverrides] = None,
         mm_processor_kwargs: Optional[dict[str, Any]] = None,
         override_pooler_config: Optional[PoolerConfig] = None,
+        kv_cache_memory_bytes: Optional[int] = None,
         compilation_config: Optional[Union[int, dict[str, Any],
                                            CompilationConfig]] = None,
         logits_processors: Optional[list[Union[str,
@@ -251,6 +260,7 @@ class LLM:
             tokenizer_revision=tokenizer_revision,
             seed=seed,
             gpu_memory_utilization=gpu_memory_utilization,
+            kv_cache_memory_bytes=kv_cache_memory_bytes,
             swap_space=swap_space,
             cpu_offload_gb=cpu_offload_gb,
             enforce_eager=enforce_eager,
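A hedged usage sketch of the new constructor argument (model and size are placeholders):

```python
from vllm import LLM

# Pin the KV cache to exactly 4 GiB per GPU instead of deriving its size from
# gpu_memory_utilization; when kv_cache_memory_bytes is set, the
# utilization-based sizing is skipped.
llm = LLM(
    model="facebook/opt-125m",            # placeholder model
    kv_cache_memory_bytes=4 * (1 << 30),  # 4 GiB
)
outputs = llm.generate("Hello, my name is")
```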

View File

@@ -2791,7 +2791,10 @@ def memory_profiling(
     result.torch_peak_increase = diff_profile.torch_peak
     result.non_torch_increase = diff_from_create.non_torch_memory
     result.profile_time = diff_profile.timestamp
-    result.non_kv_cache_memory = result.non_torch_increase + result.torch_peak_increase + result.weights_memory # noqa
+
+    non_torch_memory = result.non_torch_increase
+    peak_activation_memory = result.torch_peak_increase
+    result.non_kv_cache_memory = non_torch_memory + peak_activation_memory + result.weights_memory # noqa
 
 # Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630 # noqa: E501
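To make the renamed terms concrete, a small arithmetic sketch (byte counts invented for illustration):

```python
GiB = 1 << 30

weights_memory = 14 * GiB         # model weights loaded on this GPU
peak_activation_memory = 5 * GiB  # result.torch_peak_increase during the dummy run
non_torch_memory = 1 * GiB        # result.non_torch_increase (NCCL buffers, CUDA context, ...)

# Everything the profiler attributes to things other than the KV cache:
non_kv_cache_memory = non_torch_memory + peak_activation_memory + weights_memory
assert non_kv_cache_memory == 20 * GiB
```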

View File

@@ -355,7 +355,8 @@ def report_usage_stats(
             vllm_config.cache_config.block_size,
             "gpu_memory_utilization":
             vllm_config.cache_config.gpu_memory_utilization,
+            "kv_cache_memory_bytes":
+            vllm_config.cache_config.kv_cache_memory_bytes,
             # Quantization
             "quantization":
             vllm_config.model_config.quantization,

View File

@@ -3041,12 +3041,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.encoder_cache.clear()
         gc.collect()
 
-    def capture_model(self) -> None:
+    def capture_model(self) -> int:
         if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
             logger.warning(
                 "Skipping CUDA graph capture. To turn on CUDA graph capture, "
                 "ensure `cudagraph_mode` was not manually set to `NONE`")
-            return
+            return 0
         else:
             self.initialize_cudagraph_capture()
@@ -3117,6 +3117,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # This usually takes 5~20 seconds.
         logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                     elapsed_time, cuda_graph_size / (1 << 30))
+        return cuda_graph_size
 
     def _capture_cudagraphs(self, compilation_cases: list[int],
                             cudagraph_runtime_mode: CUDAGraphMode,
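The returned byte count lets the worker fold CUDA graph memory into its sizing suggestion. A hedged sketch of one common way such a figure can be obtained (a free-memory delta around capture; the runner's actual bookkeeping may differ):

```python
import torch

def measure_capture_bytes(capture_fn) -> int:
    """Measure GPU memory consumed by a capture routine via a free-memory delta.

    Illustrative only; assumes no other process allocates GPU memory meanwhile.
    """
    free_before, _ = torch.cuda.mem_get_info()
    capture_fn()  # e.g. run the CUDA graph capture loop
    free_after, _ = torch.cuda.mem_get_info()
    return max(free_before - free_after, 0)
```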

View File

@@ -231,18 +231,40 @@ class Worker(WorkerBase):
         You may limit the usage of GPU memory
         by adjusting the `gpu_memory_utilization` parameter.
         """
+        GiB = lambda b: b / GiB_bytes
+
+        if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes:
+            # We still need a profile run, which compiles the model for
+            # max_num_batched_tokens.
+            self.model_runner.profile_run()
+            msg = (
+                f"Initial free memory "
+                f"{GiB(self.init_snapshot.free_memory):.2f} GiB; reserved "
+                f"{GiB(kv_cache_memory_bytes):.2f} GiB for the KV cache as "
+                "specified by the kv_cache_memory_bytes config and skipped "
+                "memory profiling. This does not respect the "
+                "gpu_memory_utilization config. Only use "
+                "kv_cache_memory_bytes when you want manual control over the "
+                "KV cache memory size. If you hit OOM, check the difference "
+                "in initial free memory between the current run and the "
+                "previous run where kv_cache_memory_bytes was suggested, and "
+                "update the value accordingly.")
+            logger.info(msg)
+            return kv_cache_memory_bytes
+
         torch.cuda.empty_cache()
         torch.cuda.reset_peak_memory_stats()
 
-        GiB = lambda b: b / GiB_bytes
-
         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
         with memory_profiling(
                 self.init_snapshot,
-                weights_memory=int(
-                    self.model_runner.model_memory_usage)) as profile_result:
+                weights_memory=int(self.model_runner.model_memory_usage),
+        ) as profile_result:
             self.model_runner.profile_run()
 
+        self.non_torch_memory = profile_result.non_torch_increase
+        self.peak_activation_memory = profile_result.torch_peak_increase
+
         free_gpu_memory = profile_result.after_profile.free_memory
         # NOTE(woosuk): Here we assume that the other processes using the same
         # GPU did not change their memory usage during the profiling.
@@ -254,7 +276,7 @@ class Worker(WorkerBase):
             "release GPU memory while vLLM is profiling during initialization. "
             "To fix this, ensure consistent GPU memory allocation or "
             "isolate vLLM in its own container.")
-        available_kv_cache_memory = self.requested_memory \
+        self.available_kv_cache_memory_bytes = self.requested_memory \
             - profile_result.non_kv_cache_memory
 
         unrequested_memory = self.init_snapshot.free_memory \
@@ -274,10 +296,10 @@ class Worker(WorkerBase):
         )
         logger.debug(profile_result)
         logger.info("Available KV cache memory: %.2f GiB",
-                    GiB(available_kv_cache_memory))
+                    GiB(self.available_kv_cache_memory_bytes))
 
         gc.collect()
-        return int(available_kv_cache_memory)
+        return int(self.available_kv_cache_memory_bytes)
 
     def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         return self.model_runner.get_kv_cache_spec()
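A hypothetical helper summarizing the two paths above (names are mine, not from the diff):

```python
from typing import Optional

def plan_kv_cache_bytes(total_gpu_memory: int,
                        gpu_memory_utilization: float,
                        non_kv_cache_memory: int,
                        kv_cache_memory_bytes: Optional[int] = None) -> int:
    """Mirror the two sizing paths: manual override vs. profiled budget."""
    if kv_cache_memory_bytes is not None:
        # Manual mode: trust the user-provided byte count and skip the
        # utilization-based arithmetic entirely.
        return kv_cache_memory_bytes
    # Automatic mode: spend the requested budget minus everything that is not
    # KV cache (weights, peak activations, non-torch allocations).
    requested_memory = int(total_gpu_memory * gpu_memory_utilization)
    return max(requested_memory - non_kv_cache_memory, 0)
```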
@@ -317,8 +339,56 @@ class Worker(WorkerBase):
         # cuda graph capture.
         kernel_warmup(self)
 
+        cuda_graph_memory_bytes = 0
         if not self.model_config.enforce_eager:
-            self.model_runner.capture_model()
+            cuda_graph_memory_bytes = self.model_runner.capture_model()
+
+        if (self.cache_config.kv_cache_memory_bytes is None
+                and hasattr(self, "peak_activation_memory")):
+            # Suggest an optimal KV cache memory size when we rely on
+            # memory_profiling to guess it. Profiling provides
+            # peak_activation_memory and a few other memory consumers, but it
+            # does not account for CUDAGraph memory and may not utilize all
+            # GPU memory. Users may want fine-grained control of the KV cache
+            # memory size via kv_cache_memory_bytes instead.
+            GiB = lambda b: round(b / GiB_bytes, 2)
+
+            # We have empirically observed that memory profiling may slightly
+            # underestimate memory consumption, so leave a small buffer
+            # (150 MiB) to avoid OOM.
+            redundancy_buffer_memory = 150 * (1 << 20)
+            non_kv_cache_memory = (self.model_runner.model_memory_usage +
+                                   self.peak_activation_memory +
+                                   self.non_torch_memory +
+                                   cuda_graph_memory_bytes)
+            kv_cache_memory_bytes_to_gpu_limit = (
+                self.init_snapshot.free_memory - non_kv_cache_memory -
+                redundancy_buffer_memory)
+            kv_cache_memory_bytes_to_requested_limit = (
+                int(self.requested_memory) - non_kv_cache_memory -
+                redundancy_buffer_memory)
+
+            msg = (
+                f"Free memory on device "
+                f"({GiB(self.init_snapshot.free_memory)}/"
+                f"{GiB(self.init_snapshot.total_memory)} GiB) on startup. "
+                f"Desired GPU memory utilization is "
+                f"({self.cache_config.gpu_memory_utilization}, "
+                f"{GiB(self.requested_memory)} GiB). "
+                f"Actual usage is {GiB(self.model_runner.model_memory_usage)} "
+                f"GiB for weights, {GiB(self.peak_activation_memory)} GiB "
+                f"for peak activation, {GiB(self.non_torch_memory)} GiB "
+                f"for non-torch memory, and {GiB(cuda_graph_memory_bytes)} "
+                f"GiB for CUDAGraph memory. Replace the gpu_memory_utilization "
+                f"config with `--kv-cache-memory-bytes="
+                f"{kv_cache_memory_bytes_to_requested_limit}` to fit into the "
+                f"requested memory, or `--kv-cache-memory-bytes="
+                f"{kv_cache_memory_bytes_to_gpu_limit}` to fully "
+                f"utilize GPU memory. Current KV cache memory in use is "
+                f"{int(self.available_kv_cache_memory_bytes)} bytes.")
+
+            logger.info(msg)
+
         # Warm up sampler and preallocate memory buffer for logits and other
         # sampling related tensors of max possible shape to avoid memory
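A worked numeric sketch of the two suggested values in that message (numbers invented for illustration):

```python
GiB, MiB = 1 << 30, 1 << 20

free_memory_on_startup = 79 * GiB
requested_memory = 72 * GiB             # total memory x gpu_memory_utilization

# Measured during startup: weights + peak activation + non-torch + CUDA graphs.
non_kv_cache_memory = 22 * GiB
redundancy_buffer_memory = 150 * MiB    # safety margin against underestimation

# Fits within the gpu_memory_utilization budget:
to_requested_limit = requested_memory - non_kv_cache_memory - redundancy_buffer_memory
# Uses (nearly) all free GPU memory, ignoring the utilization budget:
to_gpu_limit = free_memory_on_startup - non_kv_cache_memory - redundancy_buffer_memory

print(to_requested_limit, to_gpu_limit)  # candidate --kv-cache-memory-bytes values
```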

View File

@@ -1337,8 +1337,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         return self.lora_manager.list_adapters()
 
     @torch.inference_mode()
-    def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
-        """Cuda graph capture a model.
+    def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> int:
+        """Cuda graph capture a model and return the CUDA graph memory
+        consumption in bytes.
 
         Note that CUDA graph's performance gain is negligible if number
         of batched tokens are larger than 200. And since CUDA graph
@@ -1505,6 +1506,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         # This usually takes < 10 seconds.
         logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                     elapsed_time, cuda_graph_size / GiB_bytes)
+        return cuda_graph_size
 
     def _update_inputs_to_capture_for_enc_dec_model(self,
                                                     capture_inputs: Dict[str,

View File

@@ -229,6 +229,67 @@ class Worker(LocalOrDistributedWorkerBase):
         self.model_runner.save_tensorized_model(
             tensorizer_config=tensorizer_config, )
 
+    @torch.inference_mode()
+    def determine_available_kv_cache_memory(self,
+                                            total_gpu_memory: int) -> float:
+        if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes:
+            # We still need a profile run, which compiles the model for
+            # max_num_batched_tokens.
+            self.model_runner.profile_run()
+            GiB = lambda b: b / GiB_bytes
+            msg = (
+                f"Initial free memory "
+                f"{GiB(self.baseline_snapshot.free_memory):.2f} GiB; "
+                f"reserved {GiB(kv_cache_memory_bytes):.2f} GiB for the KV "
+                "cache as specified by the kv_cache_memory_bytes config and "
+                "skipped memory profiling. This does not respect the "
+                "gpu_memory_utilization config. Only use "
+                "kv_cache_memory_bytes when you want manual control over the "
+                "KV cache memory size. If you hit OOM, check the difference "
+                "in initial free memory between the current run and the "
+                "previous run where kv_cache_memory_bytes was suggested, and "
+                "update the value accordingly.")
+            logger.info(msg)
+            return self.cache_config.kv_cache_memory_bytes
+
+        # Execute a forward pass with dummy inputs to profile the memory usage
+        # of the model.
+        with memory_profiling(
+                self.baseline_snapshot,
+                weights_memory=self.model_runner.model_memory_usage) as result:
+            self.model_runner.profile_run()
+
+        self.non_torch_memory = result.non_torch_increase
+        self.peak_activation_memory = result.torch_peak_increase
+
+        self._assert_memory_footprint_increased_during_profiling()
+
+        self.requested_memory = total_gpu_memory * \
+            self.cache_config.gpu_memory_utilization
+        self.available_kv_cache_memory = (self.requested_memory -
+                                          result.non_kv_cache_memory)
+
+        msg = (f"Memory profiling takes {result.profile_time:.2f} seconds\n"
+               "the current vLLM instance can use "
+               "total_gpu_memory "
+               f"({(total_gpu_memory / GiB_bytes):.2f}GiB)"
+               " x gpu_memory_utilization "
+               f"({self.cache_config.gpu_memory_utilization:.2f})"
+               f" = {(self.requested_memory / GiB_bytes):.2f}GiB\n"
+               "model weights take "
+               f"{(result.weights_memory / GiB_bytes):.2f}GiB;"
+               " non_torch_memory takes "
+               f"{(result.non_torch_increase / GiB_bytes):.2f}GiB;"
+               " PyTorch activation peak memory takes "
+               f"{(result.torch_peak_increase / GiB_bytes):.2f}GiB;"
+               " the rest of the memory reserved for KV Cache is "
+               f"{(self.available_kv_cache_memory / GiB_bytes):.2f}GiB.")
+        logger.info(msg)
+
+        return self.available_kv_cache_memory
+
     @torch.inference_mode()
     def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Profiles the peak memory usage of the model to determine how many
@@ -248,20 +309,8 @@ class Worker(LocalOrDistributedWorkerBase):
         torch.cuda.reset_peak_memory_stats()
         free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info()
 
-        # Execute a forward pass with dummy inputs to profile the memory usage
-        # of the model.
-        with memory_profiling(
-                self.baseline_snapshot,
-                weights_memory=self.model_runner.model_memory_usage) as result:
-            self.model_runner.profile_run()
-
-        self._assert_memory_footprint_increased_during_profiling()
-        memory_for_current_instance = total_gpu_memory * \
-            self.cache_config.gpu_memory_utilization
-        available_kv_cache_memory = (memory_for_current_instance -
-                                     result.non_kv_cache_memory)
+        available_kv_cache_memory = self.determine_available_kv_cache_memory(
+            total_gpu_memory)
 
         # Calculate the number of blocks that can be allocated with the
         # profiled peak memory.
@@ -276,23 +325,6 @@ class Worker(LocalOrDistributedWorkerBase):
         num_gpu_blocks = max(num_gpu_blocks, 0)
         num_cpu_blocks = max(num_cpu_blocks, 0)
 
-        msg = (f"Memory profiling takes {result.profile_time:.2f} seconds\n"
-               "the current vLLM instance can use "
-               "total_gpu_memory "
-               f"({(total_gpu_memory / GiB_bytes):.2f}GiB)"
-               " x gpu_memory_utilization "
-               f"({self.cache_config.gpu_memory_utilization:.2f})"
-               f" = {(memory_for_current_instance / GiB_bytes):.2f}GiB\n"
-               "model weights take "
-               f"{(result.weights_memory / GiB_bytes):.2f}GiB;"
-               " non_torch_memory takes "
-               f"{(result.non_torch_increase / GiB_bytes):.2f}GiB;"
-               " PyTorch activation peak memory takes "
-               f"{(result.torch_peak_increase / GiB_bytes):.2f}GiB;"
-               " the rest of the memory reserved for KV Cache is "
-               f"{(available_kv_cache_memory / GiB_bytes):.2f}GiB.")
-        logger.info(msg)
 
         # Final cleanup
         gc.collect()
@@ -382,8 +414,58 @@ class Worker(LocalOrDistributedWorkerBase):
             for size in sorted(warmup_sizes, reverse=True):
                 logger.info("Compile and warming up model for size %d", size)
                 self.model_runner._dummy_run(size)
+
+        cuda_graph_memory_bytes = 0
         if not self.model_config.enforce_eager:
-            self.model_runner.capture_model(self.gpu_cache)
+            cuda_graph_memory_bytes = self.model_runner.capture_model(
+                self.gpu_cache)
+
+        if (self.cache_config.kv_cache_memory_bytes is None
+                and hasattr(self, "peak_activation_memory")):
+            # Suggest an optimal KV cache memory size when we rely on
+            # memory_profiling to guess it. Profiling provides
+            # peak_activation_memory and a few other memory consumers, but it
+            # does not account for CUDAGraph memory and may not utilize all
+            # GPU memory. Users may want fine-grained control of the KV cache
+            # memory size via kv_cache_memory_bytes instead.
+            GiB = lambda b: round(b / GiB_bytes, 2)
+            non_kv_cache_memory = (self.model_runner.model_memory_usage +
+                                   self.peak_activation_memory +
+                                   self.non_torch_memory +
+                                   cuda_graph_memory_bytes)
+
+            # We have empirically observed that memory profiling may slightly
+            # underestimate memory consumption, so leave a small buffer
+            # (150 MiB) to avoid OOM.
+            redundancy_buffer_memory = 150 * (1 << 20)
+            kv_cache_memory_bytes_to_gpu_limit = (
+                self.baseline_snapshot.free_memory - non_kv_cache_memory -
+                redundancy_buffer_memory)
+            kv_cache_memory_bytes_to_requested_limit = (
+                int(self.requested_memory) - non_kv_cache_memory -
+                redundancy_buffer_memory)
+
+            msg = (
+                f"Free memory on device "
+                f"({GiB(self.baseline_snapshot.free_memory)}/"
+                f"{GiB(self.baseline_snapshot.total_memory)} GiB) on startup. "
+                f"Desired GPU memory utilization is "
+                f"({self.cache_config.gpu_memory_utilization}, "
+                f"{GiB(self.requested_memory)} GiB). "
+                f"Actual usage is {GiB(self.model_runner.model_memory_usage)} "
+                f"GiB for weights, {GiB(self.peak_activation_memory)} GiB "
+                f"for peak activation, {GiB(self.non_torch_memory)} GiB "
+                f"for non-torch memory, and {GiB(cuda_graph_memory_bytes)} "
+                f"GiB for CUDAGraph memory. Replace the gpu_memory_utilization "
+                f"config with `--kv-cache-memory-bytes="
+                f"{kv_cache_memory_bytes_to_requested_limit}` to fit into the "
+                f"requested memory, or `--kv-cache-memory-bytes="
+                f"{kv_cache_memory_bytes_to_gpu_limit}` to fully "
+                f"utilize GPU memory. Current KV cache memory in use is "
+                f"{int(self.available_kv_cache_memory)} bytes.")
+
+            logger.info(msg)
+
         # Reset the seed to ensure that the random state is not affected by
         # the model initialization and profiling.
         set_random_seed(self.model_config.seed)