mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-20 01:16:59 +08:00
[USAGE] Improve error handling for weight initialization in Unquantized… (#20321)
Signed-off-by: Rafael Marcelino Koike <rafael.koike@oracle.com> Signed-off-by: Rafael Koike <koike.rafael@gmail.com>
This commit is contained in:
parent
740f0647b1
commit
b834b4cbf1
@ -25,7 +25,7 @@ from vllm.model_executor.layers.quantization.base_config import (
|
|||||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||||
from vllm.model_executor.models.vision import get_vit_attn_backend
|
from vllm.model_executor.models.vision import get_vit_attn_backend
|
||||||
from vllm.platforms import _Backend, current_platform
|
from vllm.platforms import _Backend, current_platform
|
||||||
from vllm.utils import direct_register_custom_op
|
from vllm.utils import GiB_bytes, direct_register_custom_op
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
USE_XFORMERS_OPS = None
|
USE_XFORMERS_OPS = None
|
||||||
@ -225,9 +225,26 @@ class Attention(nn.Module, AttentionLayerBase):
|
|||||||
).parallel_config.pipeline_parallel_size)
|
).parallel_config.pipeline_parallel_size)
|
||||||
]
|
]
|
||||||
|
|
||||||
self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
|
try:
|
||||||
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
|
self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT,
|
||||||
self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
|
dtype=torch.float32)
|
||||||
|
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT,
|
||||||
|
dtype=torch.float32)
|
||||||
|
self.v_range = torch.tensor(envs.V_SCALE_CONSTANT,
|
||||||
|
dtype=torch.float32)
|
||||||
|
except torch.cuda.OutOfMemoryError as e:
|
||||||
|
logger.error(
|
||||||
|
"Failed to initialize attention q/k/v range constants: %s", e)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
logger.debug("CUDA device: %s", torch.cuda.current_device())
|
||||||
|
logger.debug("Allocated: %.2f GiB",
|
||||||
|
torch.cuda.memory_allocated() / GiB_bytes)
|
||||||
|
logger.debug("Reserved: %.2f GiB",
|
||||||
|
torch.cuda.memory_reserved() / GiB_bytes)
|
||||||
|
raise RuntimeError(
|
||||||
|
"Failed to initialize q/k/v range constants. "
|
||||||
|
"This may be caused by insufficient memory to allocate "
|
||||||
|
"kv cache.") from e
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -29,6 +29,7 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
|
|||||||
# yapf: enable
|
# yapf: enable
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils import GiB_bytes
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -190,10 +191,27 @@ class UnquantizedLinearMethod(LinearMethodBase):
|
|||||||
output_partition_sizes: list[int], input_size: int,
|
output_partition_sizes: list[int], input_size: int,
|
||||||
output_size: int, params_dtype: torch.dtype,
|
output_size: int, params_dtype: torch.dtype,
|
||||||
**extra_weight_attrs):
|
**extra_weight_attrs):
|
||||||
weight = Parameter(torch.empty(sum(output_partition_sizes),
|
# This method creates unquantized linear weights.
|
||||||
input_size_per_partition,
|
# The weights are not quantized, and they are not sharded.
|
||||||
dtype=params_dtype),
|
# The amount of memory allocated for the weights is
|
||||||
requires_grad=False)
|
# sum(output_partition_sizes) * input_size_per_partition.
|
||||||
|
try:
|
||||||
|
weight = Parameter(torch.empty(sum(output_partition_sizes),
|
||||||
|
input_size_per_partition,
|
||||||
|
dtype=params_dtype),
|
||||||
|
requires_grad=False)
|
||||||
|
except torch.cuda.OutOfMemoryError as e:
|
||||||
|
logger.error("Failed to create unquantized linear weights: %s", e)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
logger.debug("CUDA device: %s", torch.cuda.current_device())
|
||||||
|
logger.debug("Allocated: %.2f GiB",
|
||||||
|
torch.cuda.memory_allocated() / GiB_bytes)
|
||||||
|
logger.debug("Reserved: %.2f GiB",
|
||||||
|
torch.cuda.memory_reserved() / GiB_bytes)
|
||||||
|
raise RuntimeError(
|
||||||
|
"Failed to create unquantized linear weights. "
|
||||||
|
"This may be caused by insufficient memory to allocate "
|
||||||
|
"the weight.") from e
|
||||||
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
|
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
|
||||||
layer.register_parameter("weight", weight)
|
layer.register_parameter("weight", weight)
|
||||||
set_weight_attrs(weight, extra_weight_attrs)
|
set_weight_attrs(weight, extra_weight_attrs)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user