Isotr0py 6ac5e06f7c
[Chore] Clean up pytorch helper functions in vllm.utils (#26908)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: isotr0py <2037008807@qq.com>
2025-10-18 09:48:22 -07:00

646 lines
22 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import Any, Optional
import gguf
import torch
from gguf import GGMLQuantizationType as WeightType
from torch.nn.parameter import Parameter, UninitializedParameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase
from vllm.model_executor.layers.linear import (
LinearBase,
LinearMethodBase,
UnquantizedLinearMethod,
)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils.torch_utils import direct_register_custom_op
# Module-level logger, created through vLLM's ``init_logger`` helper.
logger = init_logger(__name__)
class GGUFConfig(QuantizationConfig):
    """Quantization config for GGUF checkpoints.

    Args:
        unquantized_modules: Name fragments of modules that must stay
            unquantized. A linear layer whose prefix contains any of these
            fragments is given ``UnquantizedLinearMethod``.
    """

    def __init__(self, unquantized_modules: list[str] | None = None) -> None:
        super().__init__()
        # Normalize ``None`` (and an empty list) to a fresh empty list.
        self.unquantized_modules = unquantized_modules if unquantized_modules else []

    def __repr__(self) -> str:
        return "GGUFConfig()"

    def get_name(self) -> QuantizationMethods:
        """Return the identifier of this quantization method."""
        return "gguf"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        """Return the activation dtypes accepted by this method."""
        supported = [torch.half, torch.bfloat16, torch.float32]
        return supported

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required.

        NOTE(review): presumably an encoded GPU compute capability;
        confirm against the base-class contract.
        """
        return 60

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Return extra config files to look for (none for GGUF)."""
        return []  # no extra configs.

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "GGUFConfig":
        """Build a default config; the ``config`` dict is not consulted."""
        return cls()

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Pick the quantize method for ``layer``, or ``None`` if unsupported.

        Linear layers get the GGUF linear method unless their prefix matches
        one of ``unquantized_modules``; embeddings and fused-MoE layers get
        their dedicated GGUF methods.
        """
        if isinstance(layer, LinearBase):
            skipped = is_layer_skipped_gguf(prefix, self.unquantized_modules)
            return UnquantizedLinearMethod() if skipped else GGUFLinearMethod(self)
        if isinstance(layer, VocabParallelEmbedding):
            return GGUFEmbeddingMethod(self)
        if isinstance(layer, FusedMoE):
            return GGUFMoEMethod(self, layer.moe_config)
        return None
def is_layer_skipped_gguf(prefix: str, unquantized_modules: list[str]):
    """Return True when ``prefix`` contains any of ``unquantized_modules``.

    Matching is a plain substring test, so an empty list never matches.
    """
    for candidate in unquantized_modules:
        if candidate in prefix:
            return True
    return False
# Weight types that the code below handles with a plain matmul / embedding
# lookup, without calling any GGML kernel.
UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}

# The ``Q<bits>_<n>`` family of GGML quantization types.
STANDARD_QUANT_TYPES = {
    WeightType.Q4_0,
    WeightType.Q4_1,
    WeightType.Q5_0,
    WeightType.Q5_1,
    WeightType.Q8_0,
    WeightType.Q8_1,
}

# The ``Q<bits>_K`` family of GGML quantization types.
KQUANT_TYPES = {
    WeightType.Q2_K,
    WeightType.Q3_K,
    WeightType.Q4_K,
    WeightType.Q5_K,
    WeightType.Q6_K,
}

# The ``IQ*`` (I-Matrix) family of GGML quantization types.
IMATRIX_QUANT_TYPES = {
    WeightType.IQ1_M,
    WeightType.IQ1_S,
    WeightType.IQ2_XXS,
    WeightType.IQ2_XS,
    WeightType.IQ2_S,
    WeightType.IQ3_XXS,
    WeightType.IQ3_S,
    WeightType.IQ4_XS,
    WeightType.IQ4_NL,
}

# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
# MMQ kernel for I-Matrix quantization.
# Each derived set is a distinct object built from the families above.
DEQUANT_TYPES = STANDARD_QUANT_TYPES.union(KQUANT_TYPES, IMATRIX_QUANT_TYPES)
MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES.union(KQUANT_TYPES, IMATRIX_QUANT_TYPES)
MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES.union(KQUANT_TYPES)
def _fused_mul_mat_gguf(
    x: torch.Tensor, qweight: torch.Tensor, qweight_type: int
) -> torch.Tensor:
    """Multiply ``x`` by a GGUF weight, picking the kernel by type and batch.

    Dispatch order:
      1. Empty ``x`` -> empty output, no kernel call.
      2. Unquantized weight -> plain ``x @ qweight.T``.
      3. Few rows and an MMVQ-capable type -> ``ggml_mul_mat_vec_a8``.
      4. MMQ-capable type -> ``ggml_mul_mat_a8``.
      5. Any other dequantizable type -> dequantize, then matmul.

    Args:
        x: Input activations; ``x.shape[0]`` is treated as the row count.
        qweight: The (possibly quantized) weight; ``qweight.shape[0]`` is the
            output width.
        qweight_type: Integer GGML quantization type of ``qweight``.

    Returns:
        The product tensor with ``x.shape[0]`` rows.

    Raises:
        NotImplementedError: If ``qweight_type`` is a valid GGML type that no
            path above supports.
        ValueError: If ``qweight_type`` is not a valid GGML type at all
            (raised by the ``WeightType`` enum conversion).
    """
    num_rows = x.shape[0]
    out_features = qweight.shape[0]

    # HACK: when doing chunked prefill we don't generate output tokens
    # so input to logits generator is empty which causes invalid parameter
    if num_rows == 0:
        return torch.empty(num_rows, out_features, dtype=x.dtype, device=x.device)

    # there is no need to call any kernel for fp16/bf16
    if qweight_type in UNQUANTIZED_TYPES:
        return x @ qweight.T

    # Largest row count still routed to the mat-vec kernel; it depends on
    # the quantization family and on whether the output is wider than 5120.
    is_wide = out_features > 5120
    if qweight_type in IMATRIX_QUANT_TYPES:
        mmvq_safe = 8 if is_wide else 16
    else:
        mmvq_safe = 2 if is_wide else 6

    # Small batches go through the MMVQ (mat-vec) kernel.
    if num_rows <= mmvq_safe and qweight_type in MMVQ_QUANT_TYPES:
        return ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, out_features)

    # Use MMQ Kernel if it's available (standard + k-quants)
    if qweight_type in MMQ_QUANT_TYPES:
        return ops.ggml_mul_mat_a8(qweight, x, qweight_type, out_features)

    # If there is no available MMQ kernel, fallback to dequantize
    if qweight_type in DEQUANT_TYPES:
        block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
        dequant_shape = (out_features, qweight.shape[1] // type_size * block_size)
        weight = ops.ggml_dequantize(qweight, qweight_type, *dequant_shape, x.dtype)
        return x @ weight.T

    # Raise an error if the quantization type is not supported.
    # Might be useful if llama.cpp adds a new quantization type.
    # Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
    unsupported = WeightType(qweight_type)
    raise NotImplementedError(f"Unsupported GGUF quantization type: {unsupported}")
def _fused_mul_mat_gguf_fake(
x: torch.Tensor,
qweight: torch.Tensor,
qweight_type: int,
) -> torch.Tensor:
return torch.empty(x.shape[0], qweight.shape[0], dtype=x.dtype, device=x.device)
# Register the matmul together with its fake implementation as a vLLM custom
# op, then expose the registered op under a public module-level name.
# The previous ``try/except AttributeError`` wrapper only re-raised the same
# exception, so it was a no-op and has been removed.
direct_register_custom_op(
    op_name="_fused_mul_mat_gguf",
    op_func=_fused_mul_mat_gguf,
    fake_impl=_fused_mul_mat_gguf_fake,
)
fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf
def _fused_moe_gguf(
    x: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    qweight_type: int,
    qweight_type2: int,
    activation: str,
) -> torch.Tensor:
    """Run a GGUF-quantized mixture-of-experts layer.

    Three paths are tried in order:
      1. Batched MMQ kernel (``ggml_moe_a8``) when both weight types are in
         ``MMQ_QUANT_TYPES`` and ``x`` has more than 64 rows.
      2. Vector kernel (``ggml_moe_a8_vec``) when both weight types are in
         ``MMVQ_QUANT_TYPES``.
      3. A slow per-token, per-expert Python loop built on
         ``fused_mul_mat_gguf``.

    Args:
        x: Input hidden states; the fast paths require it to be 2-D.
        w1: Stacked first-projection expert weights, indexed by expert on
            dim 0.
        w2: Stacked second-projection expert weights, indexed by expert on
            dim 0.
        topk_weights: Routing weight for each (token, selected expert).
        topk_ids: Expert index for each (token, selected expert).
        qweight_type: GGML quantization type of ``w1``.
        qweight_type2: GGML quantization type of ``w2``.
        activation: ``"silu"`` or ``"gelu"``.

    Returns:
        A tensor shaped like ``x`` holding the combined expert outputs.

    Raises:
        ValueError: If ``activation`` is neither ``"silu"`` nor ``"gelu"``.
    """

    def act(gate_up: torch.Tensor):
        # The ``*_and_mul`` ops write into a buffer whose last dim is half of
        # the input's last dim.
        half = gate_up.shape[-1] // 2
        gated = torch.empty(
            gate_up.shape[:-1] + (half,), dtype=gate_up.dtype, device=gate_up.device
        )
        if activation == "silu":
            torch.ops._C.silu_and_mul(gated, gate_up)
        elif activation == "gelu":
            torch.ops._C.gelu_and_mul(gated, gate_up)
        else:
            raise ValueError(f"Unsupported activation: {activation}")
        return gated

    # lazy import to avoid triggering triton import in CPU backend
    from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size

    out_hidden_states = torch.empty_like(x)

    # Unless there is decent expert reuse (more than 64 rows) we are better
    # off running the moe_vec kernel, so the MMQ path also checks row count.
    both_mmq = qweight_type2 in MMQ_QUANT_TYPES and qweight_type in MMQ_QUANT_TYPES
    if both_mmq and x.shape[0] > 64:
        num_tokens, _ = x.shape
        num_experts, w1_rows, _ = w1.shape
        top_k = topk_ids.shape[1]
        block_size = ops.ggml_moe_get_block_size(qweight_type)
        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
            topk_ids, block_size, num_experts
        )
        first_proj = ops.ggml_moe_a8(
            x,
            w1,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            qweight_type,
            w1_rows,
            top_k,
            num_tokens,
        )
        activated = act(first_proj)
        second_proj = ops.ggml_moe_a8(
            activated,
            w2,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            qweight_type2,
            w2.shape[1],
            1,
            num_tokens * top_k,
        )
        # Scale each expert output by its routing weight, then reduce over
        # the selected experts into the output buffer.
        routing = topk_weights.view(num_tokens, top_k, 1)
        weighted = second_proj.reshape(num_tokens, top_k, w2.shape[1]).mul_(routing)
        ops.moe_sum(weighted, out_hidden_states)
        return out_hidden_states

    if qweight_type2 in MMVQ_QUANT_TYPES and qweight_type in MMVQ_QUANT_TYPES:
        num_tokens, _ = x.shape
        _, w1_rows, _ = w1.shape
        top_k = topk_ids.shape[1]
        first_proj = ops.ggml_moe_a8_vec(
            x, w1, topk_ids, top_k, qweight_type, w1_rows, num_tokens
        )
        activated = act(first_proj)
        second_proj = ops.ggml_moe_a8_vec(
            activated, w2, topk_ids, 1, qweight_type2, w2.shape[1], num_tokens * top_k
        )
        routing = topk_weights.view(num_tokens, top_k, 1)
        weighted = second_proj.reshape(num_tokens, top_k, w2.shape[1]).mul_(routing)
        ops.moe_sum(weighted, out_hidden_states)
        return out_hidden_states

    logger.warning_once(
        "There is no support for fast MoE kernel "
        "for current quantization method. "
        "Falling back to slow implementation. "
    )
    for tok, (tok_weights, tok_experts) in enumerate(zip(topk_weights, topk_ids)):
        # Keep a leading dim of 1 so the matmul sees a one-row batch.
        inp = x[tok].reshape((1,) + x.shape[1:])
        accumulated = None
        for weight, expert in zip(tok_weights, tok_experts):
            first_proj = fused_mul_mat_gguf(inp, w1[expert], qweight_type)
            activated = act(first_proj)
            contribution = fused_mul_mat_gguf(
                activated, w2[expert], qweight_type2
            ).mul_(weight)
            if accumulated is None:
                accumulated = contribution
            else:
                accumulated.add_(contribution)
        out_hidden_states[tok] = accumulated
    return out_hidden_states
def _fused_moe_gguf_fake(
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
qweight_type: int,
qweight_type2: int,
activation: str,
) -> torch.Tensor:
return torch.empty_like(x)
# Register the MoE implementation together with its fake implementation as a
# vLLM custom op, then expose the registered op under a public name.
# The previous ``try/except AttributeError`` wrapper only re-raised the same
# exception, so it was a no-op and has been removed.
direct_register_custom_op(
    op_name="_fused_moe_gguf",
    op_func=_fused_moe_gguf,
    fake_impl=_fused_moe_gguf_fake,
)
fused_moe_gguf = torch.ops.vllm._fused_moe_gguf
def _apply_gguf_embedding(
    x: torch.Tensor,
    qweight: torch.Tensor,
    qweight_type: int,
    hidden_size: int,
    dtype: torch.dtype | None = None,
) -> torch.Tensor:
    """Look up embeddings from a GGUF weight table.

    Unquantized tables use ``torch.embedding`` directly. Quantized tables
    gather the selected rows first and dequantize only those rows.

    Args:
        x: Index tensor of any shape.
        qweight: The embedding table, one (possibly quantized) row per entry.
        qweight_type: Integer GGML quantization type of ``qweight``.
        hidden_size: Width of one dequantized embedding row.
        dtype: Dtype passed to the dequantize kernel; not used on the
            unquantized path.

    Returns:
        A tensor of shape ``(*x.shape, hidden_size)``.

    Raises:
        NotImplementedError: If ``qweight_type`` is a valid GGML type that is
            neither unquantized nor dequantizable.
    """
    if qweight_type in UNQUANTIZED_TYPES:
        return torch.embedding(qweight, x)

    if qweight_type not in DEQUANT_TYPES:
        unsupported = WeightType(qweight_type)
        raise NotImplementedError(
            f"Unsupported GGUF quantization type: {unsupported}"
        )

    block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
    flat_ids = x.flatten()
    # The packed row width must decode to exactly ``hidden_size`` values.
    assert hidden_size == qweight.shape[1] // type_size * block_size
    picked_rows = qweight.index_select(0, flat_ids)
    dequant = ops.ggml_dequantize(
        picked_rows, qweight_type, hidden_size, flat_ids.shape[0], dtype
    )
    return dequant.view(*x.shape, hidden_size)
def _apply_gguf_embedding_fake(
x: torch.Tensor,
qweight: torch.Tensor,
qweight_type: int,
hidden_size: int,
dtype: torch.dtype | None = None,
) -> torch.Tensor:
return torch.empty(x.shape[0], hidden_size, dtype=dtype, device=x.device)
# Register the embedding lookup together with its fake implementation as a
# vLLM custom op, then expose the registered op under a public name.
# The previous ``try/except AttributeError`` wrapper only re-raised the same
# exception, so it was a no-op and has been removed.
direct_register_custom_op(
    op_name="_apply_gguf_embedding",
    op_func=_apply_gguf_embedding,
    fake_impl=_apply_gguf_embedding_fake,
)
apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding
class GGUFLinearMethod(LinearMethodBase):
    """Linear method for GGUF.

    Args:
        quant_config: The GGUF quantization config.
    """

    def __init__(self, quant_config: GGUFConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the ``qweight`` and ``qweight_type`` parameters on ``layer``.

        ``qweight`` starts uninitialized and carries loader metadata
        (logical shape, shard bookkeeping, a container for loaded shards).
        ``qweight_type`` has one ``uint8`` slot per output partition and
        records the GGML type(s) of the loaded weight.

        Args:
            layer: Module that receives the parameters.
            input_size_per_partition: Input width of this partition.
            output_partition_sizes: Output width of each merged sub-layer.
            input_size: Not used by this method.
            output_size: Not used by this method.
            params_dtype: Remembered on ``self`` for later use by subclasses.
            **extra_weight_attrs: Extra attributes attached to both parameters.
        """
        self.params_dtype = params_dtype
        output_size_per_partition = sum(output_partition_sizes)

        tensor_shape = (output_size_per_partition, input_size_per_partition)
        qweight = GGUFUninitializedParameter(requires_grad=False)
        set_weight_attrs(
            qweight,
            {
                "input_dim": 1,
                "output_dim": 0,
                "tensor_shape": tensor_shape,
                "is_gguf_weight": True,
                "data_container": [],
                "shard_id": [],
                "shard_id_map": {},
            },
        )
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("qweight", qweight)

        qweight_type = Parameter(
            torch.empty(len(output_partition_sizes), dtype=torch.uint8),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight_type,
            {
                "is_gguf_weight_type": True,
                "weight_type": 0,
                "shard_weight_type": {},
                "ignore_warning": True,
            },
        )
        set_weight_attrs(qweight_type, extra_weight_attrs)
        layer.register_parameter("qweight_type", qweight_type)

    def process_weights_after_loading(self, layer: torch.nn.Module):
        """Validate the loaded weight type and build the padded weight.

        Raises:
            ValueError: If the layer's weight type is neither unquantized nor
                dequantizable.
        """
        qweight_type = layer.qweight_type.weight_type
        if not (qweight_type in UNQUANTIZED_TYPES or qweight_type in DEQUANT_TYPES):
            qweight_type = WeightType(qweight_type)
            raise ValueError(
                f"Unsupported GGUF quantization type {qweight_type} in layer {layer}."
            )
        # For MergedColumnParallelLinear and QKVParallelLinear, we need to
        # materialize the padded weight parameter for CUDA Graph compatibility.
        self._create_padded_weight_param(layer)

    def _create_padded_weight_param(self, layer: torch.nn.Module):
        """Create padded weight parameter for GGUF MergedLinear layer.

        When more than one shard was loaded, the shards are stacked along
        dim 0 into one zero-padded tensor (dim 1 is padded to the widest
        shard), and the location of each shard is recorded in
        ``shard_offset_map`` on the new parameter. With zero or one shard
        nothing is changed.

        Raises:
            ValueError: If the loaded shards do not all share one dtype.
        """
        qweight = layer.qweight
        shard_id_map = qweight.shard_id_map
        shard_id = qweight.shard_id
        data_container = qweight.data_container
        if len(data_container) <= 1:
            return

        dtypes = {data.dtype for data in data_container}
        # A real exception instead of ``assert cond, ValueError(...)``: the
        # assert form is stripped under ``python -O`` and raised an
        # AssertionError that merely wrapped the ValueError object.
        if len(dtypes) != 1:
            raise ValueError(f"Data container has mixed dtypes: {dtypes}")
        dtype = next(iter(dtypes))

        # concat dim0 and pad dim1
        padded_side = max(x.size(1) for x in data_container)
        concat_side = sum(x.size(0) for x in data_container)
        # Pad the quantized weights to dense tensor, and create a map
        # with the location of each shard in the padded tensor.
        padded_data = torch.zeros(
            (concat_side, padded_side), dtype=dtype, device=qweight.device
        )
        # (dim0_start, dim0_end, dim1_size)
        shard_offset_map = dict[str, tuple[int, int, int]]()
        for idx in shard_id:
            id_in_container = shard_id_map[idx]
            shard = data_container[id_in_container]
            # Rows occupied by all shards stored before this one.
            start = sum(x.size(0) for x in data_container[:id_in_container])
            end = start + shard.size(0)
            size = shard.size(1)
            padded_data[start:end, :size] = shard
            shard_offset_map[idx] = (start, end, size)

        # Drop the per-shard tensors now that they live in the padded copy.
        qweight.data_container.clear()
        padded_param = Parameter(padded_data, requires_grad=False)
        # Carry every attribute of the old parameter over to the new one.
        set_weight_attrs(padded_param, vars(qweight))
        set_weight_attrs(padded_param, {"shard_offset_map": shard_offset_map})
        layer.register_parameter("qweight", padded_param)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply the GGUF linear layer to ``x``.

        For merged layers (non-empty ``shard_id``) each shard is sliced out
        of the padded weight, multiplied with its own weight type, and the
        results are concatenated along dim 1. Otherwise a single matmul is
        performed. ``bias`` is added in place when given.
        """
        shard_id = layer.qweight.shard_id

        if shard_id:
            # dequantize shard weights respectively
            # When a "q" shard is present, force the q, k, v output order.
            shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id
            qweight = layer.qweight
            result = []
            for idx in shard_id:
                start, end, offset = layer.qweight.shard_offset_map[idx]
                qweight_type = layer.qweight_type.shard_weight_type[idx]
                result.append(
                    fused_mul_mat_gguf(
                        x, qweight[start:end, :offset].contiguous(), qweight_type
                    )
                )
            out = torch.cat(result, dim=1)
        else:
            qweight = layer.qweight
            qweight_type = layer.qweight_type.weight_type
            out = fused_mul_mat_gguf(x, qweight, qweight_type)
        if bias is not None:
            out.add_(bias)
        return out
class GGUFMoEMethod(FusedMoEMethodBase):
    """MoE method for GGUF.

    Args:
        quant_config: The GGUF quantization config.
    """

    def __init__(
        self,
        quant_config: GGUFConfig,
        moe: FusedMoEConfig,
    ):
        super().__init__(moe)
        self.quant_config = quant_config

    @staticmethod
    def _register_gguf_expert_weight(
        layer: torch.nn.Module,
        name: str,
        tensor_shape: tuple[int, int, int],
        extra_weight_attrs: dict[str, Any],
    ) -> None:
        """Register ``<name>`` and ``<name>_type`` parameters on ``layer``."""
        qweight = GGUFUninitializedParameter(requires_grad=False)
        set_weight_attrs(
            qweight,
            {
                "input_dim": 1,
                "output_dim": 0,
                "tensor_shape": tensor_shape,
                "is_gguf_weight": True,
                "data_container": [],
            },
        )
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter(name, qweight)

        qweight_type = Parameter(
            torch.empty(1, dtype=torch.uint8), requires_grad=False
        )
        set_weight_attrs(
            qweight_type,
            {"is_gguf_weight_type": True, "weight_type": 0, "ignore_warning": True},
        )
        set_weight_attrs(qweight_type, extra_weight_attrs)
        layer.register_parameter(f"{name}_type", qweight_type)

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the expert weight parameters on ``layer``.

        Four parameters are registered, in this order: ``w13_qweight``,
        ``w13_qweight_type``, ``w2_qweight`` and ``w2_qweight_type``.
        ``params_dtype`` is not used by this method.
        """
        # gate up proj
        self._register_gguf_expert_weight(
            layer,
            "w13_qweight",
            (num_experts, 2 * intermediate_size_per_partition, hidden_size),
            extra_weight_attrs,
        )
        # down proj
        self._register_gguf_expert_weight(
            layer,
            "w2_qweight",
            (num_experts, intermediate_size_per_partition, hidden_size),
            extra_weight_attrs,
        )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Return ``None``: this method provides no fused-MoE quant config."""
        return None

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Route tokens to experts and run the GGUF MoE op.

        Expert selection is delegated to ``FusedMoE.select_experts``; the
        selected weights and ids are then passed to ``fused_moe_gguf``
        together with the layer's quantized expert weights.

        Raises:
            NotImplementedError: If ``enable_eplb`` or
                ``apply_router_weight_on_input`` is set.
            AssertionError: If ``self.fused_experts`` is set or
                ``activation`` is not ``"silu"``.
        """
        assert self.fused_experts is None

        if enable_eplb:
            raise NotImplementedError("EPLB not supported for `GGUFMoEMethod` yet.")

        assert activation == "silu", "Only SiLU activation is supported."
        if apply_router_weight_on_input:
            # The trailing space keeps the two literals from running
            # together as "forfused" in the rendered message.
            raise NotImplementedError(
                "Apply router weight on input is not supported for "
                "fused GGUF MoE method."
            )

        topk_weights, topk_ids, _ = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
        )
        return fused_moe_gguf(
            x,
            layer.w13_qweight,
            layer.w2_qweight,
            topk_weights,
            topk_ids,
            layer.w13_qweight_type.weight_type,
            layer.w2_qweight_type.weight_type,
            activation,
        )
class GGUFEmbeddingMethod(GGUFLinearMethod):
    """Embedding method for GGUF.

    Args:
        quant_config: The GGUF quantization config.
    """

    def embedding(self, layer: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
        """Look up embeddings for the indices in ``x``.

        The row width comes from the second entry of the weight's
        ``tensor_shape`` attribute, and the output dtype from the
        ``params_dtype`` stored on this method object.
        """
        table = layer.qweight
        return apply_gguf_embedding(
            x,
            table,
            layer.qweight_type.weight_type,
            table.tensor_shape[1],
            dtype=self.params_dtype,
        )
class GGUFUninitializedParameter(UninitializedParameter):
    """Uninitialized parameter used as the placeholder for GGUF weights."""

    # Class this parameter turns into when materialized.
    # NOTE(review): semantics come from torch's UninitializedParameter; confirm.
    cls_to_become = Parameter
    # Tensors collected for this parameter; the linear method in this file
    # reads them to build the padded weight and then clears the list.
    data_container: list[torch.Tensor]