mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 20:28:42 +08:00
Allow users to specify kv cache memory size (#21489)
Signed-off-by: Boyuan Feng <boyuan@meta.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
fd1ce98cdd
commit
94e6b2d55f
@ -113,6 +113,15 @@ class CacheConfig:
|
||||
necessary for implementing this optimization in some models (e.g. Gemma3n)
|
||||
"""
|
||||
|
||||
kv_cache_memory_bytes: Optional[int] = None
|
||||
"""Size of KV Cache per GPU in bytes. By default, this is set to None
|
||||
and vllm can automatically infer the kv cache size based on
|
||||
gpu_memory_utilization. However, users may want to manually specify
|
||||
the kv cache memory size. kv_cache_memory_bytes allows more fine-grain
|
||||
control of how much memory gets used when compared with using
|
||||
gpu_memory_memory_utilization. Note that kv_cache_memory_bytes
|
||||
(when not-None) ignores gpu_memory_utilization"""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
|
||||
@ -227,8 +227,14 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
|
||||
elif contains_type(type_hints, int):
|
||||
kwargs[name]["type"] = int
|
||||
# Special case for large integers
|
||||
if name in {"max_model_len", "max_num_batched_tokens"}:
|
||||
human_readable_ints = {
|
||||
"max_model_len",
|
||||
"max_num_batched_tokens",
|
||||
"kv_cache_memory_bytes",
|
||||
}
|
||||
if name in human_readable_ints:
|
||||
kwargs[name]["type"] = human_readable_int
|
||||
kwargs[name]["help"] += f"\n\n{human_readable_int.__doc__}"
|
||||
elif contains_type(type_hints, float):
|
||||
kwargs[name]["type"] = float
|
||||
elif (contains_type(type_hints, dict)
|
||||
@ -335,6 +341,7 @@ class EngineArgs:
|
||||
swap_space: float = CacheConfig.swap_space
|
||||
cpu_offload_gb: float = CacheConfig.cpu_offload_gb
|
||||
gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
|
||||
kv_cache_memory_bytes: Optional[int] = CacheConfig.kv_cache_memory_bytes
|
||||
max_num_batched_tokens: Optional[
|
||||
int] = SchedulerConfig.max_num_batched_tokens
|
||||
max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
|
||||
@ -734,6 +741,8 @@ class EngineArgs:
|
||||
cache_group.add_argument("--block-size", **cache_kwargs["block_size"])
|
||||
cache_group.add_argument("--gpu-memory-utilization",
|
||||
**cache_kwargs["gpu_memory_utilization"])
|
||||
cache_group.add_argument("--kv-cache-memory-bytes",
|
||||
**cache_kwargs["kv_cache_memory_bytes"])
|
||||
cache_group.add_argument("--swap-space", **cache_kwargs["swap_space"])
|
||||
cache_group.add_argument("--kv-cache-dtype",
|
||||
**cache_kwargs["cache_dtype"])
|
||||
@ -1174,6 +1183,7 @@ class EngineArgs:
|
||||
cache_config = CacheConfig(
|
||||
block_size=self.block_size,
|
||||
gpu_memory_utilization=self.gpu_memory_utilization,
|
||||
kv_cache_memory_bytes=self.kv_cache_memory_bytes,
|
||||
swap_space=self.swap_space,
|
||||
cache_dtype=self.kv_cache_dtype,
|
||||
is_attention_free=model_config.is_attention_free,
|
||||
|
||||
@ -278,7 +278,8 @@ class LLMEngine:
|
||||
self.cache_config.block_size,
|
||||
"gpu_memory_utilization":
|
||||
self.cache_config.gpu_memory_utilization,
|
||||
|
||||
"kv_cache_memory_bytes":
|
||||
self.cache_config.kv_cache_memory_bytes,
|
||||
# Quantization
|
||||
"quantization":
|
||||
self.model_config.quantization,
|
||||
|
||||
@ -110,6 +110,14 @@ class LLM:
|
||||
values will increase the KV cache size and thus improve the model's
|
||||
throughput. However, if the value is too high, it may cause out-of-
|
||||
memory (OOM) errors.
|
||||
kv_cache_memory_bytes: Size of KV Cache per GPU in bytes. By default,
|
||||
this is set to None and vllm can automatically infer the kv cache
|
||||
size based on gpu_memory_utilization. However, users may want to
|
||||
manually specify the kv cache memory size. kv_cache_memory_bytes
|
||||
allows more fine-grain control of how much memory gets used when
|
||||
compared with using gpu_memory_memory_utilization. Note that
|
||||
kv_cache_memory_bytes (when not-None) ignores
|
||||
gpu_memory_utilization
|
||||
swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
|
||||
This can be used for temporarily storing the states of the requests
|
||||
when their `best_of` sampling parameters are larger than 1. If all
|
||||
@ -184,6 +192,7 @@ class LLM:
|
||||
hf_overrides: Optional[HfOverrides] = None,
|
||||
mm_processor_kwargs: Optional[dict[str, Any]] = None,
|
||||
override_pooler_config: Optional[PoolerConfig] = None,
|
||||
kv_cache_memory_bytes: Optional[int] = None,
|
||||
compilation_config: Optional[Union[int, dict[str, Any],
|
||||
CompilationConfig]] = None,
|
||||
logits_processors: Optional[list[Union[str,
|
||||
@ -251,6 +260,7 @@ class LLM:
|
||||
tokenizer_revision=tokenizer_revision,
|
||||
seed=seed,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
kv_cache_memory_bytes=kv_cache_memory_bytes,
|
||||
swap_space=swap_space,
|
||||
cpu_offload_gb=cpu_offload_gb,
|
||||
enforce_eager=enforce_eager,
|
||||
|
||||
@ -2791,7 +2791,10 @@ def memory_profiling(
|
||||
result.torch_peak_increase = diff_profile.torch_peak
|
||||
result.non_torch_increase = diff_from_create.non_torch_memory
|
||||
result.profile_time = diff_profile.timestamp
|
||||
result.non_kv_cache_memory = result.non_torch_increase + result.torch_peak_increase + result.weights_memory # noqa
|
||||
|
||||
non_torch_memory = result.non_torch_increase
|
||||
peak_activation_memory = result.torch_peak_increase
|
||||
result.non_kv_cache_memory = non_torch_memory + peak_activation_memory + result.weights_memory # noqa
|
||||
|
||||
|
||||
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630 # noqa: E501
|
||||
|
||||
@ -355,7 +355,8 @@ def report_usage_stats(
|
||||
vllm_config.cache_config.block_size,
|
||||
"gpu_memory_utilization":
|
||||
vllm_config.cache_config.gpu_memory_utilization,
|
||||
|
||||
"kv_cache_memory_bytes":
|
||||
vllm_config.cache_config.kv_cache_memory_bytes,
|
||||
# Quantization
|
||||
"quantization":
|
||||
vllm_config.model_config.quantization,
|
||||
|
||||
@ -3041,12 +3041,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.encoder_cache.clear()
|
||||
gc.collect()
|
||||
|
||||
def capture_model(self) -> None:
|
||||
def capture_model(self) -> int:
|
||||
if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
|
||||
logger.warning(
|
||||
"Skipping CUDA graph capture. To turn on CUDA graph capture, "
|
||||
"ensure `cudagraph_mode` was not manually set to `NONE`")
|
||||
return
|
||||
return 0
|
||||
else:
|
||||
self.initialize_cudagraph_capture()
|
||||
|
||||
@ -3117,6 +3117,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# This usually takes 5~20 seconds.
|
||||
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
|
||||
elapsed_time, cuda_graph_size / (1 << 30))
|
||||
return cuda_graph_size
|
||||
|
||||
def _capture_cudagraphs(self, compilation_cases: list[int],
|
||||
cudagraph_runtime_mode: CUDAGraphMode,
|
||||
|
||||
@ -231,18 +231,40 @@ class Worker(WorkerBase):
|
||||
You may limit the usage of GPU memory
|
||||
by adjusting the `gpu_memory_utilization` parameter.
|
||||
"""
|
||||
GiB = lambda b: b / GiB_bytes
|
||||
if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes:
|
||||
# still need a profile run which compiles the model for
|
||||
# max_num_batched_tokens
|
||||
self.model_runner.profile_run()
|
||||
|
||||
msg = (
|
||||
f"Initial free memory {GiB(self.init_snapshot.free_memory)} "
|
||||
f"GiB, reserved {GiB(kv_cache_memory_bytes):.2f}GiB memory for "
|
||||
"KV Cache as specified by kv_cache_memory_bytes config and "
|
||||
"skipped memory profiling. This does does not respect the "
|
||||
"gpu_memory_utilization config. Only use kv_cache_memory_bytes "
|
||||
"config when you want manual control of KV cache memory "
|
||||
"size. If OOM'ed, check the difference of initial free "
|
||||
"memory between the current run and the previous run "
|
||||
"where kv_cache_memory_bytes is suggested and update it "
|
||||
"correspondingly.")
|
||||
logger.info(msg)
|
||||
return kv_cache_memory_bytes
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
GiB = lambda b: b / GiB_bytes
|
||||
|
||||
# Execute a forward pass with dummy inputs to profile the memory usage
|
||||
# of the model.
|
||||
with memory_profiling(
|
||||
self.init_snapshot,
|
||||
weights_memory=int(
|
||||
self.model_runner.model_memory_usage)) as profile_result:
|
||||
weights_memory=int(self.model_runner.model_memory_usage),
|
||||
) as profile_result:
|
||||
self.model_runner.profile_run()
|
||||
|
||||
self.non_torch_memory = profile_result.non_torch_increase
|
||||
self.peak_activation_memory = profile_result.torch_peak_increase
|
||||
|
||||
free_gpu_memory = profile_result.after_profile.free_memory
|
||||
# NOTE(woosuk): Here we assume that the other processes using the same
|
||||
# GPU did not change their memory usage during the profiling.
|
||||
@ -254,7 +276,7 @@ class Worker(WorkerBase):
|
||||
"release GPU memory while vLLM is profiling during initialization. "
|
||||
"To fix this, ensure consistent GPU memory allocation or "
|
||||
"isolate vLLM in its own container.")
|
||||
available_kv_cache_memory = self.requested_memory \
|
||||
self.available_kv_cache_memory_bytes = self.requested_memory \
|
||||
- profile_result.non_kv_cache_memory
|
||||
|
||||
unrequested_memory = self.init_snapshot.free_memory \
|
||||
@ -274,10 +296,10 @@ class Worker(WorkerBase):
|
||||
)
|
||||
logger.debug(profile_result)
|
||||
logger.info("Available KV cache memory: %.2f GiB",
|
||||
GiB(available_kv_cache_memory))
|
||||
GiB(self.available_kv_cache_memory_bytes))
|
||||
gc.collect()
|
||||
|
||||
return int(available_kv_cache_memory)
|
||||
return int(self.available_kv_cache_memory_bytes)
|
||||
|
||||
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
||||
return self.model_runner.get_kv_cache_spec()
|
||||
@ -317,8 +339,56 @@ class Worker(WorkerBase):
|
||||
# cuda graph capture.
|
||||
kernel_warmup(self)
|
||||
|
||||
cuda_graph_memory_bytes = 0
|
||||
if not self.model_config.enforce_eager:
|
||||
self.model_runner.capture_model()
|
||||
cuda_graph_memory_bytes = self.model_runner.capture_model()
|
||||
|
||||
if (self.cache_config.kv_cache_memory_bytes is None
|
||||
and hasattr(self, "peak_activation_memory")):
|
||||
# Suggests optimal kv cache memory size if we rely on
|
||||
# memory_profiling to guess the kv cache memory size which
|
||||
# provides peak_activation_memory and a few other memory
|
||||
# consumption. `memory_profiling` does not consider
|
||||
# CUDAGraph memory size and may not utilize all gpu memory.
|
||||
# Users may want fine-grained control to specify kv cache
|
||||
# memory size.
|
||||
GiB = lambda b: round(b / GiB_bytes, 2)
|
||||
|
||||
# empirically observed that the memory profiling may
|
||||
# slightly underestimate the memory consumption.
|
||||
# So leave a small buffer (=150MiB) to avoid OOM.
|
||||
redundancy_buffer_memory = 150 * (1 << 20)
|
||||
non_kv_cache_memory = (self.model_runner.model_memory_usage +
|
||||
self.peak_activation_memory +
|
||||
self.non_torch_memory +
|
||||
cuda_graph_memory_bytes)
|
||||
kv_cache_memory_bytes_to_gpu_limit = (
|
||||
self.init_snapshot.free_memory - non_kv_cache_memory -
|
||||
redundancy_buffer_memory)
|
||||
kv_cache_memory_bytes_to_requested_limit = (
|
||||
int(self.requested_memory) - non_kv_cache_memory -
|
||||
redundancy_buffer_memory)
|
||||
|
||||
msg = (
|
||||
f"Free memory on device "
|
||||
f"({GiB(self.init_snapshot.free_memory)}/"
|
||||
f"{GiB(self.init_snapshot.total_memory)} GiB) on startup. "
|
||||
f"Desired GPU memory utilization is "
|
||||
f"({self.cache_config.gpu_memory_utilization}, "
|
||||
f"{GiB(self.requested_memory)} GiB). "
|
||||
f"Actual usage is {GiB(self.model_runner.model_memory_usage)} "
|
||||
f"GiB for weight, {GiB(self.peak_activation_memory)} GiB "
|
||||
f"for peak activation, {GiB(self.non_torch_memory)} GiB "
|
||||
f"for non-torch memory, and {GiB(cuda_graph_memory_bytes)} "
|
||||
f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
|
||||
f"config with `--kv-cache-memory="
|
||||
f"{kv_cache_memory_bytes_to_requested_limit}` to fit into "
|
||||
f"requested memory, or `--kv-cache-memory="
|
||||
f"{kv_cache_memory_bytes_to_gpu_limit}` to fully "
|
||||
f"utilize gpu memory. Current kv cache memory in use is "
|
||||
f"{int(self.available_kv_cache_memory_bytes)} bytes.")
|
||||
|
||||
logger.info(msg)
|
||||
|
||||
# Warm up sampler and preallocate memory buffer for logits and other
|
||||
# sampling related tensors of max possible shape to avoid memory
|
||||
|
||||
@ -1337,8 +1337,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
return self.lora_manager.list_adapters()
|
||||
|
||||
@torch.inference_mode()
|
||||
def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
|
||||
"""Cuda graph capture a model.
|
||||
def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> int:
|
||||
"""Cuda graph capture a model and return cudagraph memory
|
||||
consumption in bytes.
|
||||
|
||||
Note that CUDA graph's performance gain is negligible if number
|
||||
of batched tokens are larger than 200. And since CUDA graph
|
||||
@ -1505,6 +1506,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
# This usually takes < 10 seconds.
|
||||
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
|
||||
elapsed_time, cuda_graph_size / GiB_bytes)
|
||||
return cuda_graph_size
|
||||
|
||||
def _update_inputs_to_capture_for_enc_dec_model(self,
|
||||
capture_inputs: Dict[str,
|
||||
|
||||
@ -229,6 +229,67 @@ class Worker(LocalOrDistributedWorkerBase):
|
||||
self.model_runner.save_tensorized_model(
|
||||
tensorizer_config=tensorizer_config, )
|
||||
|
||||
@torch.inference_mode()
|
||||
def determine_available_kv_cache_memory(self,
|
||||
total_gpu_memory: int) -> float:
|
||||
if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes:
|
||||
# still need a profile run which compiles the model for
|
||||
# max_num_batched_tokens
|
||||
self.model_runner.profile_run()
|
||||
|
||||
GiB = lambda b: b / GiB_bytes
|
||||
msg = (
|
||||
f"Initial free memory "
|
||||
f"{GiB(self.baseline_snapshot.free_memory):.2f} "
|
||||
f"GiB, reserved {GiB(kv_cache_memory_bytes):.2f}GiB memory for "
|
||||
"KV Cache as specified by kv_cache_memory_bytes config and "
|
||||
"skipped memory profiling. This does does not respect the "
|
||||
"gpu_memory_utilization config. Only use kv_cache_memory_bytes "
|
||||
"config when you want manual control of KV cache memory "
|
||||
"size. If OOM'ed, check the difference of initial free "
|
||||
"memory between the current run and the previous run "
|
||||
"where kv_cache_memory_bytes is suggested and update it "
|
||||
"correspondingly.")
|
||||
logger.info(msg)
|
||||
return self.cache_config.kv_cache_memory_bytes
|
||||
|
||||
# Execute a forward pass with dummy inputs to profile the memory usage
|
||||
# of the model.
|
||||
with memory_profiling(
|
||||
self.baseline_snapshot,
|
||||
weights_memory=self.model_runner.model_memory_usage) as result:
|
||||
self.model_runner.profile_run()
|
||||
|
||||
self.non_torch_memory = result.non_torch_increase
|
||||
self.peak_activation_memory = result.torch_peak_increase
|
||||
|
||||
self._assert_memory_footprint_increased_during_profiling()
|
||||
|
||||
self.requested_memory = total_gpu_memory * \
|
||||
self.cache_config.gpu_memory_utilization
|
||||
|
||||
self.available_kv_cache_memory = (self.requested_memory -
|
||||
result.non_kv_cache_memory)
|
||||
|
||||
msg = (f"Memory profiling takes {result.profile_time:.2f} seconds\n"
|
||||
"the current vLLM instance can use "
|
||||
"total_gpu_memory "
|
||||
f"({(total_gpu_memory / GiB_bytes):.2f}GiB)"
|
||||
" x gpu_memory_utilization "
|
||||
f"({self.cache_config.gpu_memory_utilization:.2f})"
|
||||
f" = {(self.requested_memory / GiB_bytes):.2f}GiB\n"
|
||||
"model weights take "
|
||||
f"{(result.weights_memory / GiB_bytes):.2f}GiB;"
|
||||
" non_torch_memory takes "
|
||||
f"{(result.non_torch_increase / GiB_bytes):.2f}GiB;"
|
||||
" PyTorch activation peak memory takes "
|
||||
f"{(result.torch_peak_increase / GiB_bytes):.2f}GiB;"
|
||||
" the rest of the memory reserved for KV Cache is "
|
||||
f"{(self.available_kv_cache_memory / GiB_bytes):.2f}GiB.")
|
||||
|
||||
logger.info(msg)
|
||||
return self.available_kv_cache_memory
|
||||
|
||||
@torch.inference_mode()
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
"""Profiles the peak memory usage of the model to determine how many
|
||||
@ -248,20 +309,8 @@ class Worker(LocalOrDistributedWorkerBase):
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info()
|
||||
|
||||
# Execute a forward pass with dummy inputs to profile the memory usage
|
||||
# of the model.
|
||||
with memory_profiling(
|
||||
self.baseline_snapshot,
|
||||
weights_memory=self.model_runner.model_memory_usage) as result:
|
||||
self.model_runner.profile_run()
|
||||
|
||||
self._assert_memory_footprint_increased_during_profiling()
|
||||
|
||||
memory_for_current_instance = total_gpu_memory * \
|
||||
self.cache_config.gpu_memory_utilization
|
||||
available_kv_cache_memory = (memory_for_current_instance -
|
||||
result.non_kv_cache_memory)
|
||||
available_kv_cache_memory = self.determine_available_kv_cache_memory(
|
||||
total_gpu_memory)
|
||||
|
||||
# Calculate the number of blocks that can be allocated with the
|
||||
# profiled peak memory.
|
||||
@ -276,23 +325,6 @@ class Worker(LocalOrDistributedWorkerBase):
|
||||
num_gpu_blocks = max(num_gpu_blocks, 0)
|
||||
num_cpu_blocks = max(num_cpu_blocks, 0)
|
||||
|
||||
msg = (f"Memory profiling takes {result.profile_time:.2f} seconds\n"
|
||||
"the current vLLM instance can use "
|
||||
"total_gpu_memory "
|
||||
f"({(total_gpu_memory / GiB_bytes):.2f}GiB)"
|
||||
" x gpu_memory_utilization "
|
||||
f"({self.cache_config.gpu_memory_utilization:.2f})"
|
||||
f" = {(memory_for_current_instance / GiB_bytes):.2f}GiB\n"
|
||||
"model weights take "
|
||||
f"{(result.weights_memory / GiB_bytes):.2f}GiB;"
|
||||
" non_torch_memory takes "
|
||||
f"{(result.non_torch_increase / GiB_bytes):.2f}GiB;"
|
||||
" PyTorch activation peak memory takes "
|
||||
f"{(result.torch_peak_increase / GiB_bytes):.2f}GiB;"
|
||||
" the rest of the memory reserved for KV Cache is "
|
||||
f"{(available_kv_cache_memory / GiB_bytes):.2f}GiB.")
|
||||
|
||||
logger.info(msg)
|
||||
# Final cleanup
|
||||
gc.collect()
|
||||
|
||||
@ -382,8 +414,58 @@ class Worker(LocalOrDistributedWorkerBase):
|
||||
for size in sorted(warmup_sizes, reverse=True):
|
||||
logger.info("Compile and warming up model for size %d", size)
|
||||
self.model_runner._dummy_run(size)
|
||||
|
||||
cuda_graph_memory_bytes = 0
|
||||
if not self.model_config.enforce_eager:
|
||||
self.model_runner.capture_model(self.gpu_cache)
|
||||
cuda_graph_memory_bytes = self.model_runner.capture_model(
|
||||
self.gpu_cache)
|
||||
|
||||
if (self.cache_config.kv_cache_memory_bytes is None
|
||||
and hasattr(self, "peak_activation_memory")):
|
||||
# Suggests optimal kv cache memory size if we rely on
|
||||
# memory_profiling to guess the kv cache memory size which
|
||||
# provides peak_activation_memory and a few other memory
|
||||
# consumption. `memory_profiling` does not consider
|
||||
# CUDAGraph memory size and may not utilize all gpu memory.
|
||||
# Users may want fine-grained control to specify kv cache
|
||||
# memory size.
|
||||
GiB = lambda b: round(b / GiB_bytes, 2)
|
||||
non_kv_cache_memory = (self.model_runner.model_memory_usage +
|
||||
self.peak_activation_memory +
|
||||
self.non_torch_memory +
|
||||
cuda_graph_memory_bytes)
|
||||
|
||||
# empirically observed that the memory profiling may
|
||||
# slightly underestimate the memory consumption.
|
||||
# So leave a small buffer (=150MiB) to avoid OOM.
|
||||
redundancy_buffer_memory = 150 * (1 << 20)
|
||||
kv_cache_memory_bytes_to_gpu_limit = (
|
||||
self.baseline_snapshot.free_memory - non_kv_cache_memory -
|
||||
redundancy_buffer_memory)
|
||||
kv_cache_memory_bytes_to_requested_limit = (
|
||||
int(self.requested_memory) - non_kv_cache_memory -
|
||||
redundancy_buffer_memory)
|
||||
|
||||
msg = (
|
||||
f"Free memory on device "
|
||||
f"({GiB(self.baseline_snapshot.free_memory)}/"
|
||||
f"{GiB(self.baseline_snapshot.total_memory)} GiB) on startup. "
|
||||
f"Desired GPU memory utilization is "
|
||||
f"({self.cache_config.gpu_memory_utilization}, "
|
||||
f"{GiB(self.requested_memory)} GiB). "
|
||||
f"Actual usage is {GiB(self.model_runner.model_memory_usage)} "
|
||||
f"GiB for weight, {GiB(self.peak_activation_memory)} GiB "
|
||||
f"for peak activation, {GiB(self.non_torch_memory)} GiB "
|
||||
f"for non-torch memory, and {GiB(cuda_graph_memory_bytes)} "
|
||||
f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
|
||||
f"config with `--kv-cache-memory="
|
||||
f"{kv_cache_memory_bytes_to_requested_limit}` to fit into "
|
||||
f"requested memory, or `--kv-cache-memory="
|
||||
f"{kv_cache_memory_bytes_to_gpu_limit}` to fully "
|
||||
f"utilize gpu memory. Current kv cache memory in use is "
|
||||
f"{int(self.available_kv_cache_memory)} bytes.")
|
||||
logger.info(msg)
|
||||
|
||||
# Reset the seed to ensure that the random state is not affected by
|
||||
# the model initialization and profiling.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user