mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-24 22:24:39 +08:00
[USAGE] Improve error handling for weight initialization in Unquantized… (#20321)
Signed-off-by: Rafael Marcelino Koike <rafael.koike@oracle.com> Signed-off-by: Rafael Koike <koike.rafael@gmail.com>
This commit is contained in:
parent
740f0647b1
commit
b834b4cbf1
@ -25,7 +25,7 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.models.vision import get_vit_attn_backend
|
||||
from vllm.platforms import _Backend, current_platform
|
||||
from vllm.utils import direct_register_custom_op
|
||||
from vllm.utils import GiB_bytes, direct_register_custom_op
|
||||
|
||||
logger = init_logger(__name__)
|
||||
USE_XFORMERS_OPS = None
|
||||
@ -225,9 +225,26 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
).parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
|
||||
self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
|
||||
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
|
||||
self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
|
||||
try:
|
||||
self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT,
|
||||
dtype=torch.float32)
|
||||
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT,
|
||||
dtype=torch.float32)
|
||||
self.v_range = torch.tensor(envs.V_SCALE_CONSTANT,
|
||||
dtype=torch.float32)
|
||||
except torch.cuda.OutOfMemoryError as e:
|
||||
logger.error(
|
||||
"Failed to initialize attention q/k/v range constants: %s", e)
|
||||
if torch.cuda.is_available():
|
||||
logger.debug("CUDA device: %s", torch.cuda.current_device())
|
||||
logger.debug("Allocated: %.2f GiB",
|
||||
torch.cuda.memory_allocated() / GiB_bytes)
|
||||
logger.debug("Reserved: %.2f GiB",
|
||||
torch.cuda.memory_reserved() / GiB_bytes)
|
||||
raise RuntimeError(
|
||||
"Failed to initialize q/k/v range constants. "
|
||||
"This may be caused by insufficient memory to allocate "
|
||||
"kv cache.") from e
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -29,6 +29,7 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
# yapf: enable
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import GiB_bytes
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -190,10 +191,27 @@ class UnquantizedLinearMethod(LinearMethodBase):
|
||||
output_partition_sizes: list[int], input_size: int,
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
weight = Parameter(torch.empty(sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
# This method creates unquantized linear weights.
|
||||
# The weights are not quantized, and they are not sharded.
|
||||
# The amount of memory allocated for the weights is
|
||||
# sum(output_partition_sizes) * input_size_per_partition.
|
||||
try:
|
||||
weight = Parameter(torch.empty(sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
except torch.cuda.OutOfMemoryError as e:
|
||||
logger.error("Failed to create unquantized linear weights: %s", e)
|
||||
if torch.cuda.is_available():
|
||||
logger.debug("CUDA device: %s", torch.cuda.current_device())
|
||||
logger.debug("Allocated: %.2f GiB",
|
||||
torch.cuda.memory_allocated() / GiB_bytes)
|
||||
logger.debug("Reserved: %.2f GiB",
|
||||
torch.cuda.memory_reserved() / GiB_bytes)
|
||||
raise RuntimeError(
|
||||
"Failed to create unquantized linear weights. "
|
||||
"This may be caused by insufficient memory to allocate "
|
||||
"the weight.") from e
|
||||
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
|
||||
layer.register_parameter("weight", weight)
|
||||
set_weight_attrs(weight, extra_weight_attrs)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user