diff --git a/vllm/config/cache.py b/vllm/config/cache.py
index bf85aad452d0..4c4e39c37ee5 100644
--- a/vllm/config/cache.py
+++ b/vllm/config/cache.py
@@ -113,6 +113,15 @@ class CacheConfig:
     necessary for implementing this optimization in some models (e.g. Gemma3n)
     """

+    kv_cache_memory_bytes: Optional[int] = None
+    """Size of the KV cache per GPU in bytes. By default, this is set to
+    None and vLLM will automatically infer the KV cache size based on
+    gpu_memory_utilization. However, users may want to specify the KV cache
+    memory size manually. kv_cache_memory_bytes allows finer-grained control
+    over how much memory is used than gpu_memory_utilization. Note that when
+    kv_cache_memory_bytes is set (not None), gpu_memory_utilization is
+    ignored."""
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index be456af4d19d..ba1543a8d5a3 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -227,8 +227,14 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
         elif contains_type(type_hints, int):
             kwargs[name]["type"] = int
             # Special case for large integers
-            if name in {"max_model_len", "max_num_batched_tokens"}:
+            human_readable_ints = {
+                "max_model_len",
+                "max_num_batched_tokens",
+                "kv_cache_memory_bytes",
+            }
+            if name in human_readable_ints:
                 kwargs[name]["type"] = human_readable_int
+                kwargs[name]["help"] += f"\n\n{human_readable_int.__doc__}"
         elif contains_type(type_hints, float):
             kwargs[name]["type"] = float
         elif (contains_type(type_hints, dict)
@@ -335,6 +341,7 @@ class EngineArgs:
     swap_space: float = CacheConfig.swap_space
     cpu_offload_gb: float = CacheConfig.cpu_offload_gb
     gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
+    kv_cache_memory_bytes: Optional[int] = CacheConfig.kv_cache_memory_bytes
     max_num_batched_tokens: Optional[
         int] = SchedulerConfig.max_num_batched_tokens
     max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
@@ -734,6 +741,8 @@ class EngineArgs:
         cache_group.add_argument("--block-size", **cache_kwargs["block_size"])
         cache_group.add_argument("--gpu-memory-utilization",
                                  **cache_kwargs["gpu_memory_utilization"])
+        cache_group.add_argument("--kv-cache-memory-bytes",
+                                 **cache_kwargs["kv_cache_memory_bytes"])
         cache_group.add_argument("--swap-space", **cache_kwargs["swap_space"])
         cache_group.add_argument("--kv-cache-dtype",
                                  **cache_kwargs["cache_dtype"])
@@ -1174,6 +1183,7 @@ class EngineArgs:
         cache_config = CacheConfig(
             block_size=self.block_size,
             gpu_memory_utilization=self.gpu_memory_utilization,
+            kv_cache_memory_bytes=self.kv_cache_memory_bytes,
             swap_space=self.swap_space,
             cache_dtype=self.kv_cache_dtype,
             is_attention_free=model_config.is_attention_free,
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 3462142d9fb9..c303d093f632 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -278,7 +278,8 @@ class LLMEngine:
                 self.cache_config.block_size,
                 "gpu_memory_utilization":
                 self.cache_config.gpu_memory_utilization,
-
+                "kv_cache_memory_bytes":
+                self.cache_config.kv_cache_memory_bytes,
                 # Quantization
                 "quantization":
                 self.model_config.quantization,
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index e6fd61ae1aad..a6174161f115 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -110,6 +110,14 @@ class LLM:
            values will increase the KV cache size and thus improve the model's
           throughput. However, if the value is too high, it may cause out-of-
           memory (OOM) errors.
+       kv_cache_memory_bytes: Size of the KV cache per GPU in bytes. By
+           default, this is set to None and vLLM will automatically infer
+           the KV cache size based on gpu_memory_utilization. However,
+           users may want to specify the KV cache memory size manually.
+           kv_cache_memory_bytes allows finer-grained control over how much
+           memory is used than gpu_memory_utilization. Note that when
+           kv_cache_memory_bytes is set (not None), gpu_memory_utilization
+           is ignored.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
@@ -184,6 +192,7 @@ class LLM:
         hf_overrides: Optional[HfOverrides] = None,
         mm_processor_kwargs: Optional[dict[str, Any]] = None,
         override_pooler_config: Optional[PoolerConfig] = None,
+        kv_cache_memory_bytes: Optional[int] = None,
         compilation_config: Optional[Union[int, dict[str, Any],
                                            CompilationConfig]] = None,
         logits_processors: Optional[list[Union[str,
@@ -251,6 +260,7 @@ class LLM:
             tokenizer_revision=tokenizer_revision,
             seed=seed,
             gpu_memory_utilization=gpu_memory_utilization,
+            kv_cache_memory_bytes=kv_cache_memory_bytes,
             swap_space=swap_space,
             cpu_offload_gb=cpu_offload_gb,
             enforce_eager=enforce_eager,
diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py
index e166e64e1101..ea860ca10b3a 100644
--- a/vllm/utils/__init__.py
+++ b/vllm/utils/__init__.py
@@ -2791,7 +2791,10 @@ def memory_profiling(
     result.torch_peak_increase = diff_profile.torch_peak
     result.non_torch_increase = diff_from_create.non_torch_memory
     result.profile_time = diff_profile.timestamp
-    result.non_kv_cache_memory = result.non_torch_increase + result.torch_peak_increase + result.weights_memory # noqa
+
+    non_torch_memory = result.non_torch_increase
+    peak_activation_memory = result.torch_peak_increase
+    result.non_kv_cache_memory = non_torch_memory + peak_activation_memory + result.weights_memory # noqa


# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630 # noqa: E501
diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py
index e0c7d9094aa6..fd84b4a111f5 100644
--- a/vllm/v1/utils.py
+++ b/vllm/v1/utils.py
@@ -355,7 +355,8 @@ def report_usage_stats(
            vllm_config.cache_config.block_size,
            "gpu_memory_utilization":
            vllm_config.cache_config.gpu_memory_utilization,
-
+           "kv_cache_memory_bytes":
+           vllm_config.cache_config.kv_cache_memory_bytes,
            # Quantization
            "quantization":
            vllm_config.model_config.quantization,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 1b785af96a9a..ebb18e81c38a 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -3041,12 +3041,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.encoder_cache.clear()
         gc.collect()

-    def capture_model(self) -> None:
+    def capture_model(self) -> int:
         if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
             logger.warning(
                 "Skipping CUDA graph capture. To turn on CUDA graph capture, "
                 "ensure `cudagraph_mode` was not manually set to `NONE`")
-            return
+            return 0
         else:
             self.initialize_cudagraph_capture()

@@ -3117,6 +3117,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # This usually takes 5~20 seconds.
        logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                    elapsed_time, cuda_graph_size / (1 << 30))
+        return cuda_graph_size

    def _capture_cudagraphs(self, compilation_cases: list[int],
                            cudagraph_runtime_mode: CUDAGraphMode,
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 726f59603437..37dd431fd68f 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -231,18 +231,40 @@ class Worker(WorkerBase):
        You may limit the usage of GPU memory by adjusting the
        `gpu_memory_utilization` parameter.
        """
+        GiB = lambda b: b / GiB_bytes
+        if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes:
+            # We still need a profile run, which compiles the model for
+            # max_num_batched_tokens.
+            self.model_runner.profile_run()
+
+            msg = (
+                f"Initial free memory "
+                f"{GiB(self.init_snapshot.free_memory):.2f} GiB, reserved "
+                f"{GiB(kv_cache_memory_bytes):.2f} GiB memory for the KV "
+                "cache as specified by the kv_cache_memory_bytes config and "
+                "skipped memory profiling. This does not respect the "
+                "gpu_memory_utilization config. Only use the "
+                "kv_cache_memory_bytes config when you want manual control "
+                "over the KV cache memory size. If you hit OOM, compare the "
+                "initial free memory of this run against the run where this "
+                "kv_cache_memory_bytes value was suggested, and adjust it "
+                "accordingly.")
+            logger.info(msg)
+            return kv_cache_memory_bytes
+
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
-        GiB = lambda b: b / GiB_bytes

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        with memory_profiling(
                self.init_snapshot,
-                weights_memory=int(
-                    self.model_runner.model_memory_usage)) as profile_result:
+                weights_memory=int(self.model_runner.model_memory_usage),
+        ) as profile_result:
            self.model_runner.profile_run()
+
+        self.non_torch_memory = profile_result.non_torch_increase
+        self.peak_activation_memory = profile_result.torch_peak_increase
+
        free_gpu_memory = profile_result.after_profile.free_memory
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
@@ -254,7 +276,7 @@ class Worker(WorkerBase):
            "release GPU memory while vLLM is profiling during initialization. "
            "To fix this, ensure consistent GPU memory allocation or "
            "isolate vLLM in its own container.")
-        available_kv_cache_memory = self.requested_memory \
+        self.available_kv_cache_memory_bytes = self.requested_memory \
            - profile_result.non_kv_cache_memory

        unrequested_memory = self.init_snapshot.free_memory \
@@ -274,10 +296,10 @@ class Worker(WorkerBase):
        )
        logger.debug(profile_result)
        logger.info("Available KV cache memory: %.2f GiB",
-                    GiB(available_kv_cache_memory))
+                    GiB(self.available_kv_cache_memory_bytes))
        gc.collect()

-        return int(available_kv_cache_memory)
+        return int(self.available_kv_cache_memory_bytes)

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        return self.model_runner.get_kv_cache_spec()
@@ -317,8 +339,56 @@ class Worker(WorkerBase):
        # cuda graph capture.
        kernel_warmup(self)

+        cuda_graph_memory_bytes = 0
        if not self.model_config.enforce_eager:
-            self.model_runner.capture_model()
+            cuda_graph_memory_bytes = self.model_runner.capture_model()
+
+        if (self.cache_config.kv_cache_memory_bytes is None
+                and hasattr(self, "peak_activation_memory")):
+            # Suggest an optimal KV cache memory size when we relied on
+            # memory_profiling to estimate it. memory_profiling reports
+            # peak_activation_memory and a few other memory components,
+            # but it does not account for CUDAGraph memory and may not
+            # utilize all GPU memory. Users who want fine-grained control
+            # can set the KV cache memory size explicitly via
+            # kv_cache_memory_bytes.
+            GiB = lambda b: round(b / GiB_bytes, 2)
+
+            # We have empirically observed that memory profiling may
+            # slightly underestimate memory consumption, so leave a small
+            # buffer (150 MiB) to avoid OOM.
+            redundancy_buffer_memory = 150 * (1 << 20)
+            non_kv_cache_memory = (self.model_runner.model_memory_usage +
+                                   self.peak_activation_memory +
+                                   self.non_torch_memory +
+                                   cuda_graph_memory_bytes)
+            kv_cache_memory_bytes_to_gpu_limit = (
+                self.init_snapshot.free_memory - non_kv_cache_memory -
+                redundancy_buffer_memory)
+            kv_cache_memory_bytes_to_requested_limit = (
+                int(self.requested_memory) - non_kv_cache_memory -
+                redundancy_buffer_memory)
+
+            msg = (
+                f"Free memory on device "
+                f"({GiB(self.init_snapshot.free_memory)}/"
+                f"{GiB(self.init_snapshot.total_memory)} GiB) on startup. "
+                f"Desired GPU memory utilization is "
+                f"({self.cache_config.gpu_memory_utilization}, "
+                f"{GiB(self.requested_memory)} GiB). "
+                f"Actual usage is {GiB(self.model_runner.model_memory_usage)} "
+                f"GiB for weights, {GiB(self.peak_activation_memory)} GiB "
+                f"for peak activation, {GiB(self.non_torch_memory)} GiB "
+                f"for non-torch memory, and {GiB(cuda_graph_memory_bytes)} "
+                f"GiB for CUDAGraph memory. Replace the "
+                f"gpu_memory_utilization config with `--kv-cache-memory-bytes="
+                f"{kv_cache_memory_bytes_to_requested_limit}` to fit into "
+                f"the requested memory, or `--kv-cache-memory-bytes="
+                f"{kv_cache_memory_bytes_to_gpu_limit}` to fully "
+                f"utilize GPU memory. Current KV cache memory in use is "
+                f"{int(self.available_kv_cache_memory_bytes)} bytes.")
+
+            logger.info(msg)

        # Warm up sampler and preallocate memory buffer for logits and other
        # sampling related tensors of max possible shape to avoid memory
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index f05401fd0132..88f83c9dd7e6 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1337,8 +1337,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
        return self.lora_manager.list_adapters()

    @torch.inference_mode()
-    def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
-        """Cuda graph capture a model.
+    def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> int:
+        """Cuda graph capture a model and return the CUDA graph memory
+        consumption in bytes.

        Note that CUDA graph's performance gain is negligible if number
        of batched tokens are larger than 200. And since CUDA graph
@@ -1505,6 +1506,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
        # This usually takes < 10 seconds.
        logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                    elapsed_time, cuda_graph_size / GiB_bytes)
+        return cuda_graph_size

    def _update_inputs_to_capture_for_enc_dec_model(self,
                                                    capture_inputs: Dict[str,
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 142b1afce8c3..670f256c0bf6 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -229,6 +229,67 @@ class Worker(LocalOrDistributedWorkerBase):
        self.model_runner.save_tensorized_model(
            tensorizer_config=tensorizer_config, )

+    @torch.inference_mode()
+    def determine_available_kv_cache_memory(self,
+                                            total_gpu_memory: int) -> float:
+        if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes:
+            # We still need a profile run, which compiles the model for
+            # max_num_batched_tokens.
+            self.model_runner.profile_run()
+
+            GiB = lambda b: b / GiB_bytes
+            msg = (
+                f"Initial free memory "
+                f"{GiB(self.baseline_snapshot.free_memory):.2f} GiB, "
+                f"reserved {GiB(kv_cache_memory_bytes):.2f} GiB memory for "
+                "the KV cache as specified by the kv_cache_memory_bytes "
+                "config and skipped memory profiling. This does not respect "
+                "the gpu_memory_utilization config. Only use the "
+                "kv_cache_memory_bytes config when you want manual control "
+                "over the KV cache memory size. If you hit OOM, compare the "
+                "initial free memory of this run against the run where this "
+                "kv_cache_memory_bytes value was suggested, and adjust it "
+                "accordingly.")
+            logger.info(msg)
+            return self.cache_config.kv_cache_memory_bytes
+
+        # Execute a forward pass with dummy inputs to profile the memory usage
+        # of the model.
+        with memory_profiling(
+                self.baseline_snapshot,
+                weights_memory=self.model_runner.model_memory_usage) as result:
+            self.model_runner.profile_run()
+
+        self.non_torch_memory = result.non_torch_increase
+        self.peak_activation_memory = result.torch_peak_increase
+
+        self._assert_memory_footprint_increased_during_profiling()
+
+        self.requested_memory = total_gpu_memory * \
+            self.cache_config.gpu_memory_utilization
+
+        self.available_kv_cache_memory = (self.requested_memory -
+                                          result.non_kv_cache_memory)
+
+        msg = (f"Memory profiling takes {result.profile_time:.2f} seconds\n"
+               "the current vLLM instance can use "
+               "total_gpu_memory "
+               f"({(total_gpu_memory / GiB_bytes):.2f}GiB)"
+               " x gpu_memory_utilization "
+               f"({self.cache_config.gpu_memory_utilization:.2f})"
+               f" = {(self.requested_memory / GiB_bytes):.2f}GiB\n"
+               "model weights take "
+               f"{(result.weights_memory / GiB_bytes):.2f}GiB;"
+               " non_torch_memory takes "
+               f"{(result.non_torch_increase / GiB_bytes):.2f}GiB;"
+               " PyTorch activation peak memory takes "
+               f"{(result.torch_peak_increase / GiB_bytes):.2f}GiB;"
+               " the rest of the memory reserved for KV Cache is "
+               f"{(self.available_kv_cache_memory / GiB_bytes):.2f}GiB.")
+
+        logger.info(msg)
+        return self.available_kv_cache_memory
+
    @torch.inference_mode()
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Profiles the peak memory usage of the model to determine how many
@@ -248,20 +309,8 @@ class Worker(LocalOrDistributedWorkerBase):
        torch.cuda.reset_peak_memory_stats()

        free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info()
-
-        # Execute a forward pass with dummy inputs to profile the memory usage
-        # of the model.
-        with memory_profiling(
-                self.baseline_snapshot,
-                weights_memory=self.model_runner.model_memory_usage) as result:
-            self.model_runner.profile_run()
-
-        self._assert_memory_footprint_increased_during_profiling()
-
-        memory_for_current_instance = total_gpu_memory * \
-            self.cache_config.gpu_memory_utilization
-        available_kv_cache_memory = (memory_for_current_instance -
-                                     result.non_kv_cache_memory)
+        available_kv_cache_memory = self.determine_available_kv_cache_memory(
+            total_gpu_memory)

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
@@ -276,23 +325,6 @@ class Worker(LocalOrDistributedWorkerBase):
        num_gpu_blocks = max(num_gpu_blocks, 0)
        num_cpu_blocks = max(num_cpu_blocks, 0)

-        msg = (f"Memory profiling takes {result.profile_time:.2f} seconds\n"
-               "the current vLLM instance can use "
-               "total_gpu_memory "
-               f"({(total_gpu_memory / GiB_bytes):.2f}GiB)"
-               " x gpu_memory_utilization "
-               f"({self.cache_config.gpu_memory_utilization:.2f})"
-               f" = {(memory_for_current_instance / GiB_bytes):.2f}GiB\n"
-               "model weights take "
-               f"{(result.weights_memory / GiB_bytes):.2f}GiB;"
-               " non_torch_memory takes "
-               f"{(result.non_torch_increase / GiB_bytes):.2f}GiB;"
-               " PyTorch activation peak memory takes "
-               f"{(result.torch_peak_increase / GiB_bytes):.2f}GiB;"
-               " the rest of the memory reserved for KV Cache is "
-               f"{(available_kv_cache_memory / GiB_bytes):.2f}GiB.")
-
-        logger.info(msg)

        # Final cleanup
        gc.collect()
@@ -382,8 +414,58 @@ class Worker(LocalOrDistributedWorkerBase):
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compile and warming up model for size %d", size)
            self.model_runner._dummy_run(size)
+
+        cuda_graph_memory_bytes = 0
        if not self.model_config.enforce_eager:
-            self.model_runner.capture_model(self.gpu_cache)
+            cuda_graph_memory_bytes = self.model_runner.capture_model(
+                self.gpu_cache)
+
+        if (self.cache_config.kv_cache_memory_bytes is None
+                and hasattr(self, "peak_activation_memory")):
+            # Suggest an optimal KV cache memory size when we relied on
+            # memory_profiling to estimate it. memory_profiling reports
+            # peak_activation_memory and a few other memory components,
+            # but it does not account for CUDAGraph memory and may not
+            # utilize all GPU memory. Users who want fine-grained control
+            # can set the KV cache memory size explicitly via
+            # kv_cache_memory_bytes.
+            GiB = lambda b: round(b / GiB_bytes, 2)
+            non_kv_cache_memory = (self.model_runner.model_memory_usage +
+                                   self.peak_activation_memory +
+                                   self.non_torch_memory +
+                                   cuda_graph_memory_bytes)
+
+            # We have empirically observed that memory profiling may
+            # slightly underestimate memory consumption, so leave a small
+            # buffer (150 MiB) to avoid OOM.
+            redundancy_buffer_memory = 150 * (1 << 20)
+            kv_cache_memory_bytes_to_gpu_limit = (
+                self.baseline_snapshot.free_memory - non_kv_cache_memory -
+                redundancy_buffer_memory)
+            kv_cache_memory_bytes_to_requested_limit = (
+                int(self.requested_memory) - non_kv_cache_memory -
+                redundancy_buffer_memory)
+
+            msg = (
+                f"Free memory on device "
+                f"({GiB(self.baseline_snapshot.free_memory)}/"
+                f"{GiB(self.baseline_snapshot.total_memory)} GiB) on startup. "
+                f"Desired GPU memory utilization is "
+                f"({self.cache_config.gpu_memory_utilization}, "
+                f"{GiB(self.requested_memory)} GiB). "
+                f"Actual usage is {GiB(self.model_runner.model_memory_usage)} "
+                f"GiB for weights, {GiB(self.peak_activation_memory)} GiB "
+                f"for peak activation, {GiB(self.non_torch_memory)} GiB "
+                f"for non-torch memory, and {GiB(cuda_graph_memory_bytes)} "
+                f"GiB for CUDAGraph memory. Replace the "
+                f"gpu_memory_utilization config with `--kv-cache-memory-bytes="
+                f"{kv_cache_memory_bytes_to_requested_limit}` to fit into "
+                f"the requested memory, or `--kv-cache-memory-bytes="
+                f"{kv_cache_memory_bytes_to_gpu_limit}` to fully "
+                f"utilize GPU memory. Current KV cache memory in use is "
+                f"{int(self.available_kv_cache_memory)} bytes.")
+            logger.info(msg)
+
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)