diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 44cb2c7c6b642..22dc6dcbc8d62 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -25,7 +25,7 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.models.vision import get_vit_attn_backend
 from vllm.platforms import _Backend, current_platform
-from vllm.utils import direct_register_custom_op
+from vllm.utils import GiB_bytes, direct_register_custom_op
 
 logger = init_logger(__name__)
 USE_XFORMERS_OPS = None
@@ -225,9 +225,26 @@ class Attention(nn.Module, AttentionLayerBase):
             ).parallel_config.pipeline_parallel_size)
         ]
 
-        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
-        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
-        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
+        try:
+            self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT,
+                                        dtype=torch.float32)
+            self.k_range = torch.tensor(envs.K_SCALE_CONSTANT,
+                                        dtype=torch.float32)
+            self.v_range = torch.tensor(envs.V_SCALE_CONSTANT,
+                                        dtype=torch.float32)
+        except torch.cuda.OutOfMemoryError as e:
+            logger.error(
+                "Failed to initialize attention q/k/v range constants: %s", e)
+            if torch.cuda.is_available():
+                logger.debug("CUDA device: %s", torch.cuda.current_device())
+                logger.debug("Allocated: %.2f GiB",
+                             torch.cuda.memory_allocated() / GiB_bytes)
+                logger.debug("Reserved: %.2f GiB",
+                             torch.cuda.memory_reserved() / GiB_bytes)
+            raise RuntimeError(
+                "Failed to initialize q/k/v range constants. "
+                "This may be caused by insufficient memory to allocate "
+                "kv cache.") from e
 
     def forward(
         self,
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 773dfeae25d93..cd05136520977 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -29,6 +29,7 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
 # yapf: enable
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
+from vllm.utils import GiB_bytes
 
 logger = init_logger(__name__)
 
@@ -190,10 +191,27 @@ class UnquantizedLinearMethod(LinearMethodBase):
                        output_partition_sizes: list[int], input_size: int,
                        output_size: int, params_dtype: torch.dtype,
                        **extra_weight_attrs):
-        weight = Parameter(torch.empty(sum(output_partition_sizes),
-                                       input_size_per_partition,
-                                       dtype=params_dtype),
-                           requires_grad=False)
+        # This method creates unquantized linear weights.
+        # The weights are not quantized, and they are not sharded.
+        # The amount of memory allocated for the weights is
+        # sum(output_partition_sizes) * input_size_per_partition.
+        try:
+            weight = Parameter(torch.empty(sum(output_partition_sizes),
+                                           input_size_per_partition,
+                                           dtype=params_dtype),
+                               requires_grad=False)
+        except torch.cuda.OutOfMemoryError as e:
+            logger.error("Failed to create unquantized linear weights: %s", e)
+            if torch.cuda.is_available():
+                logger.debug("CUDA device: %s", torch.cuda.current_device())
+                logger.debug("Allocated: %.2f GiB",
+                             torch.cuda.memory_allocated() / GiB_bytes)
+                logger.debug("Reserved: %.2f GiB",
+                             torch.cuda.memory_reserved() / GiB_bytes)
+            raise RuntimeError(
+                "Failed to create unquantized linear weights. "
+                "This may be caused by insufficient memory to allocate "
+                "the weight.") from e
         set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
         layer.register_parameter("weight", weight)
         set_weight_attrs(weight, extra_weight_attrs)