mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-10 03:07:04 +08:00
[Bugfix] fix confusing OOM errors during v1 init (#28051)
Signed-off-by: Shivam <shivamprasad91@gmail.com> Signed-off-by: shivampr <shivampr.dev@gmail.com> Co-authored-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
parent
166ac3c94d
commit
8580919ac3
54
tests/v1/engine/test_init_error_messaging.py
Normal file
54
tests/v1/engine/test_init_error_messaging.py
Normal file
@ -0,0 +1,54 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.v1.core.kv_cache_utils import check_enough_kv_cache_memory
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||
|
||||
|
||||
def test_kv_cache_oom_no_memory():
    """Zero bytes of available memory must make the KV-cache check raise.

    Exercises the "No available memory for the cache blocks" error path in
    check_enough_kv_cache_memory.
    """
    from unittest.mock import MagicMock

    vllm_config = MagicMock()
    vllm_config.model_config.max_model_len = 2048

    kv_cache_spec = {
        "layer_0": FullAttentionSpec(
            block_size=16,
            num_kv_heads=8,
            head_size=128,
            dtype="float16",
        )
    }

    # available_memory == 0 -> the check must refuse to initialize.
    with pytest.raises(ValueError):
        check_enough_kv_cache_memory(vllm_config, kv_cache_spec, 0)
|
||||
|
||||
|
||||
def test_kv_cache_oom_insufficient_memory(monkeypatch):
    """A KV-cache requirement larger than the available budget must raise.

    max_memory_usage_bytes is monkeypatched to report a 100 GiB requirement
    while only 1 GiB is offered, driving the "larger than the available KV
    cache memory" error path in check_enough_kv_cache_memory.
    """
    from unittest.mock import MagicMock

    vllm_config = MagicMock()
    vllm_config.model_config.max_model_len = 2048
    vllm_config.cache_config.block_size = 16
    for attr in (
        "tensor_parallel_size",
        "pipeline_parallel_size",
        "decode_context_parallel_size",
    ):
        setattr(vllm_config.parallel_config, attr, 1)

    # Pretend the model needs 100 GiB of KV cache no matter the inputs.
    monkeypatch.setattr(
        "vllm.v1.core.kv_cache_utils.max_memory_usage_bytes",
        lambda cfg, spec: 100 * 1024**3,
    )

    kv_cache_spec = {
        "layer_0": FullAttentionSpec(
            block_size=16,
            num_kv_heads=8,
            head_size=128,
            dtype="float16",
        )
    }

    # Offer 1 GiB -- far short of the mocked 100 GiB requirement.
    with pytest.raises(ValueError):
        check_enough_kv_cache_memory(vllm_config, kv_cache_spec, 1024**3)
|
||||
@ -687,7 +687,9 @@ def check_enough_kv_cache_memory(
|
||||
raise ValueError(
|
||||
"No available memory for the cache blocks. "
|
||||
"Try increasing `gpu_memory_utilization` when "
|
||||
"initializing the engine."
|
||||
"initializing the engine. "
|
||||
"See https://docs.vllm.ai/en/latest/configuration/conserving_memory/ "
|
||||
"for more details."
|
||||
)
|
||||
|
||||
max_model_len = vllm_config.model_config.max_model_len
|
||||
@ -711,8 +713,10 @@ def check_enough_kv_cache_memory(
|
||||
f"cache is needed, which is larger than the available KV cache "
|
||||
f"memory ({available_memory / GiB_bytes:.2f} GiB). "
|
||||
f"{estimated_msg} "
|
||||
f"Try increasing `gpu_memory_utilization` or decreasing "
|
||||
f"`max_model_len` when initializing the engine."
|
||||
f"Try increasing `gpu_memory_utilization` or decreasing `max_model_len` "
|
||||
f"when initializing the engine. "
|
||||
f"See https://docs.vllm.ai/en/latest/configuration/conserving_memory/ "
|
||||
f"for more details."
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -3571,74 +3571,89 @@ class GPUModelRunner(
|
||||
if self.parallel_config.enable_eplb:
|
||||
self.eplb_state = EplbState(self.parallel_config, self.device)
|
||||
eplb_models = 0
|
||||
with DeviceMemoryProfiler() as m:
|
||||
time_before_load = time.perf_counter()
|
||||
model_loader = get_model_loader(self.load_config)
|
||||
self.model = model_loader.load_model(
|
||||
vllm_config=self.vllm_config, model_config=self.model_config
|
||||
)
|
||||
if self.lora_config:
|
||||
self.model = self.load_lora_model(
|
||||
self.model, self.vllm_config, self.device
|
||||
|
||||
try:
|
||||
with DeviceMemoryProfiler() as m:
|
||||
time_before_load = time.perf_counter()
|
||||
model_loader = get_model_loader(self.load_config)
|
||||
self.model = model_loader.load_model(
|
||||
vllm_config=self.vllm_config, model_config=self.model_config
|
||||
)
|
||||
if hasattr(self, "drafter"):
|
||||
logger.info_once("Loading drafter model...")
|
||||
self.drafter.load_model(self.model)
|
||||
if (
|
||||
hasattr(self.drafter, "model")
|
||||
and is_mixture_of_experts(self.drafter.model)
|
||||
and self.parallel_config.enable_eplb
|
||||
):
|
||||
spec_config = self.vllm_config.speculative_config
|
||||
assert spec_config is not None
|
||||
assert spec_config.draft_model_config is not None
|
||||
logger.info_once(
|
||||
"EPLB is enabled for drafter model %s.",
|
||||
spec_config.draft_model_config.model,
|
||||
if self.lora_config:
|
||||
self.model = self.load_lora_model(
|
||||
self.model, self.vllm_config, self.device
|
||||
)
|
||||
if hasattr(self, "drafter"):
|
||||
logger.info_once("Loading drafter model...")
|
||||
self.drafter.load_model(self.model)
|
||||
if (
|
||||
hasattr(self.drafter, "model")
|
||||
and is_mixture_of_experts(self.drafter.model)
|
||||
and self.parallel_config.enable_eplb
|
||||
):
|
||||
spec_config = self.vllm_config.speculative_config
|
||||
assert spec_config is not None
|
||||
assert spec_config.draft_model_config is not None
|
||||
logger.info_once(
|
||||
"EPLB is enabled for drafter model %s.",
|
||||
spec_config.draft_model_config.model,
|
||||
)
|
||||
|
||||
global_expert_load = (
|
||||
global_expert_loads[eplb_models]
|
||||
if global_expert_loads
|
||||
else None
|
||||
)
|
||||
old_global_expert_indices = (
|
||||
old_global_expert_indices_per_model[eplb_models]
|
||||
if old_global_expert_indices_per_model
|
||||
else None
|
||||
)
|
||||
if self.eplb_state is None:
|
||||
self.eplb_state = EplbState(self.parallel_config, self.device)
|
||||
self.eplb_state.add_model(
|
||||
self.drafter.model,
|
||||
spec_config.draft_model_config,
|
||||
global_expert_load,
|
||||
old_global_expert_indices,
|
||||
rank_mapping,
|
||||
)
|
||||
eplb_models += 1
|
||||
global_expert_load = (
|
||||
global_expert_loads[eplb_models]
|
||||
if global_expert_loads
|
||||
else None
|
||||
)
|
||||
old_global_expert_indices = (
|
||||
old_global_expert_indices_per_model[eplb_models]
|
||||
if old_global_expert_indices_per_model
|
||||
else None
|
||||
)
|
||||
if self.eplb_state is None:
|
||||
self.eplb_state = EplbState(
|
||||
self.parallel_config, self.device
|
||||
)
|
||||
self.eplb_state.add_model(
|
||||
self.drafter.model,
|
||||
spec_config.draft_model_config,
|
||||
global_expert_load,
|
||||
old_global_expert_indices,
|
||||
rank_mapping,
|
||||
)
|
||||
eplb_models += 1
|
||||
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
if not supports_eagle3(self.get_model()):
|
||||
raise RuntimeError(
|
||||
"Model does not support EAGLE3 interface but "
|
||||
"aux_hidden_state_outputs was requested"
|
||||
)
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
if not supports_eagle3(self.get_model()):
|
||||
raise RuntimeError(
|
||||
"Model does not support EAGLE3 interface but "
|
||||
"aux_hidden_state_outputs was requested"
|
||||
)
|
||||
|
||||
# Try to get auxiliary layers from speculative config,
|
||||
# otherwise use model's default layers
|
||||
aux_layers = self._get_eagle3_aux_layers_from_config()
|
||||
if aux_layers:
|
||||
logger.info(
|
||||
"Using auxiliary layers from speculative config: %s",
|
||||
aux_layers,
|
||||
)
|
||||
else:
|
||||
aux_layers = self.model.get_eagle3_aux_hidden_state_layers()
|
||||
# Try to get auxiliary layers from speculative config,
|
||||
# otherwise use model's default layers
|
||||
aux_layers = self._get_eagle3_aux_layers_from_config()
|
||||
if aux_layers:
|
||||
logger.info(
|
||||
"Using auxiliary layers from speculative config: %s",
|
||||
aux_layers,
|
||||
)
|
||||
else:
|
||||
aux_layers = self.model.get_eagle3_aux_hidden_state_layers()
|
||||
|
||||
self.model.set_aux_hidden_state_layers(aux_layers)
|
||||
time_after_load = time.perf_counter()
|
||||
self.model_memory_usage = m.consumed_memory
|
||||
self.model.set_aux_hidden_state_layers(aux_layers)
|
||||
time_after_load = time.perf_counter()
|
||||
self.model_memory_usage = m.consumed_memory
|
||||
except torch.cuda.OutOfMemoryError as e:
|
||||
msg = (
|
||||
"Failed to load model - not enough GPU memory. "
|
||||
"Try lowering --gpu-memory-utilization to free memory for weights, "
|
||||
"increasing --tensor-parallel-size, or using --quantization. "
|
||||
"See https://docs.vllm.ai/en/latest/configuration/conserving_memory/ "
|
||||
"for more tips."
|
||||
)
|
||||
combined_msg = f"{msg} (original error: {e})"
|
||||
logger.error(combined_msg)
|
||||
raise e
|
||||
logger.info_once(
|
||||
"Model loading took %.4f GiB memory and %.6f seconds",
|
||||
self.model_memory_usage / GiB_bytes,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user