[USAGE] Improve error handling for weight initialization in Unquantized… (#20321)

Signed-off-by: Rafael Marcelino Koike <rafael.koike@oracle.com>
Signed-off-by: Rafael Koike <koike.rafael@gmail.com>
Rafael Marcelino Koike 2025-09-15 12:45:49 -04:00 committed by GitHub
parent 740f0647b1
commit b834b4cbf1
2 changed files with 43 additions and 8 deletions


@@ -25,7 +25,7 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.models.vision import get_vit_attn_backend
 from vllm.platforms import _Backend, current_platform
-from vllm.utils import direct_register_custom_op
+from vllm.utils import GiB_bytes, direct_register_custom_op
 
 logger = init_logger(__name__)
 USE_XFORMERS_OPS = None
@@ -225,9 +225,26 @@ class Attention(nn.Module, AttentionLayerBase):
                 ).parallel_config.pipeline_parallel_size)
         ]
 
-        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
-        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
-        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
+        try:
+            self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT,
+                                        dtype=torch.float32)
+            self.k_range = torch.tensor(envs.K_SCALE_CONSTANT,
+                                        dtype=torch.float32)
+            self.v_range = torch.tensor(envs.V_SCALE_CONSTANT,
+                                        dtype=torch.float32)
+        except torch.cuda.OutOfMemoryError as e:
+            logger.error(
+                "Failed to initialize attention q/k/v range constants: %s", e)
+            if torch.cuda.is_available():
+                logger.debug("CUDA device: %s", torch.cuda.current_device())
+                logger.debug("Allocated: %.2f GiB",
+                             torch.cuda.memory_allocated() / GiB_bytes)
+                logger.debug("Reserved: %.2f GiB",
+                             torch.cuda.memory_reserved() / GiB_bytes)
+            raise RuntimeError(
+                "Failed to initialize q/k/v range constants. "
+                "This may be caused by insufficient memory to allocate "
+                "kv cache.") from e
 
     def forward(
         self,

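As a standalone illustration of the guarded pattern above: the following is a hypothetical, self-contained sketch, not code from this commit, and GiB_bytes here simply stands in for the 1 << 30 constant imported from vllm.utils. Note that torch.tensor(...) allocates on the CPU by default; the CUDA OOM guard matters because model initialization may run under a CUDA default device (e.g. torch.set_default_device("cuda")), in which case even these tiny factory calls land on the GPU.

import logging

import torch

logger = logging.getLogger(__name__)
GiB_bytes = 1 << 30  # stands in for vllm.utils.GiB_bytes


def init_scale_constant(value: float) -> torch.Tensor:
    # A tiny allocation, but it can still fail once the device has been
    # exhausted (e.g. by the kv cache); log memory stats before re-raising.
    try:
        return torch.tensor(value, dtype=torch.float32)
    except torch.cuda.OutOfMemoryError as e:
        logger.error("Failed to initialize range constant: %s", e)
        if torch.cuda.is_available():
            logger.debug("Allocated: %.2f GiB",
                         torch.cuda.memory_allocated() / GiB_bytes)
            logger.debug("Reserved: %.2f GiB",
                         torch.cuda.memory_reserved() / GiB_bytes)
        raise RuntimeError(
            "Failed to initialize q/k/v range constants. This may be "
            "caused by insufficient memory to allocate kv cache.") from e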

@@ -29,6 +29,7 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
 # yapf: enable
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
+from vllm.utils import GiB_bytes
 
 logger = init_logger(__name__)
 
@@ -190,10 +191,27 @@ class UnquantizedLinearMethod(LinearMethodBase):
                        output_partition_sizes: list[int], input_size: int,
                        output_size: int, params_dtype: torch.dtype,
                        **extra_weight_attrs):
-        weight = Parameter(torch.empty(sum(output_partition_sizes),
-                                       input_size_per_partition,
-                                       dtype=params_dtype),
-                           requires_grad=False)
+        # This method creates unquantized linear weights.
+        # The weights are not quantized, and they are not sharded.
+        # The amount of memory allocated for the weights is
+        # sum(output_partition_sizes) * input_size_per_partition.
+        try:
+            weight = Parameter(torch.empty(sum(output_partition_sizes),
+                                           input_size_per_partition,
+                                           dtype=params_dtype),
+                               requires_grad=False)
+        except torch.cuda.OutOfMemoryError as e:
+            logger.error("Failed to create unquantized linear weights: %s", e)
+            if torch.cuda.is_available():
+                logger.debug("CUDA device: %s", torch.cuda.current_device())
+                logger.debug("Allocated: %.2f GiB",
+                             torch.cuda.memory_allocated() / GiB_bytes)
+                logger.debug("Reserved: %.2f GiB",
+                             torch.cuda.memory_reserved() / GiB_bytes)
+            raise RuntimeError(
+                "Failed to create unquantized linear weights. "
+                "This may be caused by insufficient memory to allocate "
+                "the weight.") from e
         set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
         layer.register_parameter("weight", weight)
         set_weight_attrs(weight, extra_weight_attrs)
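The new comment spells out the allocation size in elements; as a quick sanity check, the request torch.empty makes here can be estimated in a few lines. The helper below is illustrative only (not part of this commit), and weight_alloc_bytes is a hypothetical name:

import torch


def weight_alloc_bytes(output_partition_sizes: list[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype) -> int:
    # Matches the comment above: the weight tensor holds
    # sum(output_partition_sizes) * input_size_per_partition elements.
    numel = sum(output_partition_sizes) * input_size_per_partition
    # element_size() on a 0-dim tensor gives bytes per element without
    # relying on torch.dtype.itemsize (which needs PyTorch >= 2.1).
    return numel * torch.empty((), dtype=params_dtype).element_size()


# e.g. a fused projection with two 4096-wide output partitions over a
# 4096-wide input in bfloat16 -> prints 64.0 (MiB) for one allocation.
print(weight_alloc_bytes([4096, 4096], 4096, torch.bfloat16) / (1 << 20))

Because the weight is a single torch.empty call, one contiguous block of this size must be available; with a fragmented CUDA caching allocator the request can fail even when total free memory looks sufficient, which is exactly the case the new diagnostics help to distinguish.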