# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Callable
from typing import Any, Optional

import gguf
import torch
from gguf import GGMLQuantizationType as WeightType
from torch.nn.parameter import Parameter, UninitializedParameter

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils.torch_utils import direct_register_custom_op

logger = init_logger(__name__)


class GGUFConfig(QuantizationConfig):
    """Config class for GGUF."""

    def __init__(self, unquantized_modules: list[str] | None = None) -> None:
        super().__init__()
        self.unquantized_modules = unquantized_modules or []

    def __repr__(self) -> str:
        return "GGUFConfig()"

    def get_name(self) -> QuantizationMethods:
        return "gguf"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        return [torch.half, torch.bfloat16, torch.float32]

    @classmethod
    def get_min_capability(cls) -> int:
        return 60

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []  # no extra configs.

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "GGUFConfig":
        return cls()

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        if isinstance(layer, LinearBase):
            if is_layer_skipped_gguf(prefix, self.unquantized_modules):
                return UnquantizedLinearMethod()
            return GGUFLinearMethod(self)
        elif isinstance(layer, VocabParallelEmbedding):
            return GGUFEmbeddingMethod(self)
        elif isinstance(layer, FusedMoE):
            return GGUFMoEMethod(self, layer.moe_config)
        return None


def is_layer_skipped_gguf(prefix: str, unquantized_modules: list[str]):
    return any(module_name in prefix for module_name in unquantized_modules)


UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}
STANDARD_QUANT_TYPES = {
    WeightType.Q4_0,
    WeightType.Q4_1,
    WeightType.Q5_0,
    WeightType.Q5_1,
    WeightType.Q8_0,
    WeightType.Q8_1,
}
KQUANT_TYPES = {
    WeightType.Q2_K,
    WeightType.Q3_K,
    WeightType.Q4_K,
    WeightType.Q5_K,
    WeightType.Q6_K,
}
IMATRIX_QUANT_TYPES = {
    WeightType.IQ1_M,
    WeightType.IQ1_S,
    WeightType.IQ2_XXS,
    WeightType.IQ2_XS,
    WeightType.IQ2_S,
    WeightType.IQ3_XXS,
    WeightType.IQ3_S,
    WeightType.IQ4_XS,
    WeightType.IQ4_NL,
}
# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
# MMQ kernel for I-Matrix quantization.
DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES


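# NOTE: these sets drive the kernel dispatch below: MMVQ types have a
# matrix-vector kernel, MMQ types have a full matrix-matrix kernel, and
# DEQUANT types can at least be dequantized and run through a dense GEMM
# as a fallback.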
def _fused_mul_mat_gguf(
    x: torch.Tensor, qweight: torch.Tensor, qweight_type: int
) -> torch.Tensor:
    if qweight_type in IMATRIX_QUANT_TYPES:
        mmvq_safe = 8 if qweight.shape[0] > 5120 else 16
    else:
        mmvq_safe = 2 if qweight.shape[0] > 5120 else 6
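    # mmvq_safe is the largest batch size for which the matrix-vector (MMVQ)
    # kernel is still used; the thresholds look empirically tuned and are
    # looser for I-Matrix types, which have no MMQ kernel to fall back on.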
    # HACK: when doing chunked prefill we don't generate output tokens,
    # so the input to the logits generator is empty, which would pass
    # invalid parameters to the kernels below.
    if x.shape[0] == 0:
        return torch.empty(
            x.shape[0], qweight.shape[0], dtype=x.dtype, device=x.device
        )
    # there is no need to call any kernel for fp16/bf16
    if qweight_type in UNQUANTIZED_TYPES:
        return x @ qweight.T
    # enable MMVQ in contiguous batching with batch_size=1
    if x.shape[0] <= mmvq_safe and qweight_type in MMVQ_QUANT_TYPES:
        y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
    # Use MMQ kernel if it's available (standard + k-quants)
    elif qweight_type in MMQ_QUANT_TYPES:
        y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
    # If there is no available MMQ kernel, fall back to dequantization
    elif qweight_type in DEQUANT_TYPES:
        block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
        shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
        weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype)
        y = x @ weight.T
    else:
        # Raise an error if the quantization type is not supported.
        # Might be useful if llama.cpp adds a new quantization type.
        # Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
        qweight_type = WeightType(qweight_type)
        raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
    return y


def _fused_mul_mat_gguf_fake(
    x: torch.Tensor,
    qweight: torch.Tensor,
    qweight_type: int,
) -> torch.Tensor:
    return torch.empty(x.shape[0], qweight.shape[0], dtype=x.dtype, device=x.device)


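# Register as a vLLM custom op so torch.compile / fake-tensor tracing can
# handle it; the _fake variant only reports the output shape and dtype.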
try:
    direct_register_custom_op(
        op_name="_fused_mul_mat_gguf",
        op_func=_fused_mul_mat_gguf,
        fake_impl=_fused_mul_mat_gguf_fake,
    )
    fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf

except AttributeError as error:
    raise error


def _fused_moe_gguf(
    x: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    qweight_type: int,
    qweight_type2: int,
    activation: str,
) -> torch.Tensor:
    def act(x: torch.Tensor):
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        if activation == "silu":
            torch.ops._C.silu_and_mul(out, x)
        elif activation == "gelu":
            torch.ops._C.gelu_and_mul(out, x)
        else:
            raise ValueError(f"Unsupported activation: {activation}")
        return out

    # lazy import to avoid triggering triton import in CPU backend
    from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size

    out_hidden_states = torch.empty_like(x)
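    # Three execution paths below: a batched grouped-GEMM (MMQ) path when both
    # weights have MMQ kernels and the batch is large, a matrix-vector (MMVQ)
    # path otherwise, and a per-token Python loop as the slow fallback.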
    # Unless we get decent expert reuse, we are better off running the moe_vec kernel
    if (
        qweight_type2 in MMQ_QUANT_TYPES
        and qweight_type in MMQ_QUANT_TYPES
        and x.shape[0] > 64
    ):
        num_tokens, _ = x.shape
        E, N, _ = w1.shape
        top_k = topk_ids.shape[1]
        BLOCK_SIZE = ops.ggml_moe_get_block_size(qweight_type)

        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
            topk_ids, BLOCK_SIZE, E
        )
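        # sorted_token_ids groups the (token, expert) pairs by expert and pads
        # each expert's group to a multiple of BLOCK_SIZE so the grouped GEMM
        # can process one expert per block; expert_ids maps blocks to experts.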
        out = ops.ggml_moe_a8(
            x,
            w1,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            qweight_type,
            N,
            top_k,
            num_tokens,
        )
        out = act(out)
        out = ops.ggml_moe_a8(
            out,
            w2,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            qweight_type2,
            w2.shape[1],
            1,
            num_tokens * top_k,
        )
        out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_(
            topk_weights.view(num_tokens, top_k, 1)
        )
        ops.moe_sum(out, out_hidden_states)
    elif qweight_type2 in MMVQ_QUANT_TYPES and qweight_type in MMVQ_QUANT_TYPES:
        num_tokens, _ = x.shape
        E, N, _ = w1.shape
        top_k = topk_ids.shape[1]

        out = ops.ggml_moe_a8_vec(x, w1, topk_ids, top_k, qweight_type, N, num_tokens)
        out = act(out)

        out = ops.ggml_moe_a8_vec(
            out, w2, topk_ids, 1, qweight_type2, w2.shape[1], num_tokens * top_k
        )
        out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_(
            topk_weights.view(num_tokens, top_k, 1)
        )
        ops.moe_sum(out, out_hidden_states)
    else:
        logger.warning_once(
            "There is no fast MoE kernel for the current quantization method. "
            "Falling back to the slow implementation."
        )
        for tok, (w, idx) in enumerate(zip(topk_weights, topk_ids)):
            inp = x[tok].reshape((1,) + x.shape[1:])
            current_hidden_state = None
            for ww, ii in zip(w, idx):
                expert_up = w1[ii]

                out = fused_mul_mat_gguf(inp, expert_up, qweight_type)
                out = act(out)

                expert_down = w2[ii]
                current_state = fused_mul_mat_gguf(
                    out, expert_down, qweight_type2
                ).mul_(ww)
                if current_hidden_state is None:
                    current_hidden_state = current_state
                else:
                    current_hidden_state.add_(current_state)
            out_hidden_states[tok] = current_hidden_state
    return out_hidden_states


def _fused_moe_gguf_fake(
    x: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    qweight_type: int,
    qweight_type2: int,
    activation: str,
) -> torch.Tensor:
    return torch.empty_like(x)


try:
    direct_register_custom_op(
        op_name="_fused_moe_gguf",
        op_func=_fused_moe_gguf,
        fake_impl=_fused_moe_gguf_fake,
    )
    fused_moe_gguf = torch.ops.vllm._fused_moe_gguf

except AttributeError as error:
    raise error


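# Embedding lookup on a GGUF-quantized table: instead of dequantizing the whole
# vocabulary, the rows selected by the token ids are gathered while still
# quantized, and only those rows are dequantized.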
def _apply_gguf_embedding(
    x: torch.Tensor,
    qweight: torch.Tensor,
    qweight_type: int,
    hidden_size: int,
    dtype: torch.dtype | None = None,
) -> torch.Tensor:
    if qweight_type in UNQUANTIZED_TYPES:
        return torch.embedding(qweight, x)
    elif qweight_type in DEQUANT_TYPES:
        block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
        x_flat = x.flatten()
        assert hidden_size == qweight.shape[1] // type_size * block_size
        quant = torch.index_select(qweight, dim=0, index=x_flat)
        dequant = ops.ggml_dequantize(
            quant, qweight_type, hidden_size, x_flat.shape[0], dtype
        )
        return dequant.view(*x.shape, hidden_size)
    else:
        qweight_type = WeightType(qweight_type)
        raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")


def _apply_gguf_embedding_fake(
    x: torch.Tensor,
    qweight: torch.Tensor,
    qweight_type: int,
    hidden_size: int,
    dtype: torch.dtype | None = None,
) -> torch.Tensor:
    return torch.empty(x.shape[0], hidden_size, dtype=dtype, device=x.device)


try:
    direct_register_custom_op(
        op_name="_apply_gguf_embedding",
        op_func=_apply_gguf_embedding,
        fake_impl=_apply_gguf_embedding_fake,
    )
    apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding

except AttributeError as error:
    raise error


class GGUFLinearMethod(LinearMethodBase):
    """Linear method for GGUF.

    Args:
        quant_config: The GGUF quantization config.
    """

    def __init__(self, quant_config: GGUFConfig):
        self.quant_config = quant_config

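    # GGUF tensor shapes depend on the on-disk quantization type, so weights
    # start as uninitialized parameters; merged layers collect their shards in
    # `data_container` and are materialized after loading.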
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        self.params_dtype = params_dtype
        output_size_per_partition = sum(output_partition_sizes)

        tensor_shape = (output_size_per_partition, input_size_per_partition)
        qweight = GGUFUninitializedParameter(requires_grad=False)
        set_weight_attrs(
            qweight,
            {
                "input_dim": 1,
                "output_dim": 0,
                "tensor_shape": tensor_shape,
                "is_gguf_weight": True,
                "data_container": [],
                "shard_id": [],
                "shard_id_map": {},
            },
        )
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("qweight", qweight)

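        # One uint8 per logical output shard records that shard's GGML
        # quantization type, since merged layers (e.g. QKV) can mix
        # differently quantized tensors.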
        qweight_type = Parameter(
            torch.empty(len(output_partition_sizes), dtype=torch.uint8),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight_type,
            {
                "is_gguf_weight_type": True,
                "weight_type": 0,
                "shard_weight_type": {},
                "ignore_warning": True,
            },
        )
        set_weight_attrs(qweight_type, extra_weight_attrs)
        layer.register_parameter("qweight_type", qweight_type)

    def process_weights_after_loading(self, layer: torch.nn.Module):
        qweight_type = layer.qweight_type.weight_type
        if not (qweight_type in UNQUANTIZED_TYPES or qweight_type in DEQUANT_TYPES):
            qweight_type = WeightType(qweight_type)
            raise ValueError(
                f"Unsupported GGUF quantization type {qweight_type} in layer {layer}."
            )
        # For MergedColumnParallelLinear and QKVParallelLinear, we need to
        # materialize the padded weight parameter for CUDA Graph compatibility.
        self._create_padded_weight_param(layer)

    def _create_padded_weight_param(self, layer: torch.nn.Module):
        """Create padded weight parameter for GGUF MergedLinear layer."""
        qweight = layer.qweight
        shard_id_map = qweight.shard_id_map
        shard_id = qweight.shard_id
        if len(data_container := qweight.data_container) > 1:
            dtype = {data.dtype for data in data_container}
            assert len(dtype) == 1, ValueError(
                f"Data container has mixed dtypes: {dtype}"
            )
            dtype = next(iter(dtype))
            # concat dim0 and pad dim1
            padded_side = max(x.size(1) for x in data_container)
            concat_side = sum(x.size(0) for x in data_container)
            # Pad the quantized weights into a dense tensor, and create a map
            # with the location of each shard in the padded tensor.
            padded_data = torch.zeros(
                (concat_side, padded_side), dtype=dtype, device=qweight.device
            )
            # (dim0_start, dim0_end, dim1_size)
            shard_offset_map = dict[str, tuple[int, int, int]]()
            for idx in shard_id:
                id_in_container = shard_id_map[idx]
                start = sum(x.size(0) for x in data_container[:id_in_container])
                end = start + data_container[id_in_container].size(0)
                size = data_container[id_in_container].size(1)
                padded_data[start:end, :size] = data_container[id_in_container]
                shard_offset_map[idx] = (start, end, size)
            qweight.data_container.clear()
            padded_param = Parameter(padded_data, requires_grad=False)
            set_weight_attrs(padded_param, vars(qweight))
            set_weight_attrs(padded_param, {"shard_offset_map": shard_offset_map})
            layer.register_parameter("qweight", padded_param)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        shard_id = layer.qweight.shard_id

        if shard_id:
            # Run each shard's matmul separately, since shards may use
            # different quantization types.
            shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id
            qweight = layer.qweight
            result = []
            for idx in shard_id:
                start, end, offset = layer.qweight.shard_offset_map[idx]
                qweight_type = layer.qweight_type.shard_weight_type[idx]
                result.append(
                    fused_mul_mat_gguf(
                        x, qweight[start:end, :offset].contiguous(), qweight_type
                    )
                )
            out = torch.cat(result, dim=1)
        else:
            qweight = layer.qweight
            qweight_type = layer.qweight_type.weight_type
            out = fused_mul_mat_gguf(x, qweight, qweight_type)
        if bias is not None:
            out.add_(bias)
        return out


class GGUFMoEMethod(FusedMoEMethodBase):
    """MoE method for GGUF.

    Args:
        quant_config: The GGUF quantization config.
    """

    def __init__(
        self,
        quant_config: GGUFConfig,
        moe: FusedMoEConfig,
    ):
        super().__init__(moe)
        self.quant_config = quant_config

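    # w13 holds the fused gate/up projections for all experts (hence the
    # 2 * intermediate_size output dim below); w2 holds the down projection.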
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        tensor_shape = (num_experts, 2 * intermediate_size_per_partition, hidden_size)
        # gate up proj
        w13_qweight = GGUFUninitializedParameter(requires_grad=False)
        set_weight_attrs(
            w13_qweight,
            {
                "input_dim": 1,
                "output_dim": 0,
                "tensor_shape": tensor_shape,
                "is_gguf_weight": True,
                "data_container": [],
            },
        )
        set_weight_attrs(w13_qweight, extra_weight_attrs)
        layer.register_parameter("w13_qweight", w13_qweight)

        w13_qweight_type = Parameter(
            torch.empty(1, dtype=torch.uint8), requires_grad=False
        )
        set_weight_attrs(
            w13_qweight_type,
            {"is_gguf_weight_type": True, "weight_type": 0, "ignore_warning": True},
        )
        set_weight_attrs(w13_qweight_type, extra_weight_attrs)
        layer.register_parameter("w13_qweight_type", w13_qweight_type)

        tensor_shape = (num_experts, intermediate_size_per_partition, hidden_size)
        # down proj
        w2_qweight = GGUFUninitializedParameter(requires_grad=False)
        set_weight_attrs(
            w2_qweight,
            {
                "input_dim": 1,
                "output_dim": 0,
                "tensor_shape": tensor_shape,
                "is_gguf_weight": True,
                "data_container": [],
            },
        )
        set_weight_attrs(w2_qweight, extra_weight_attrs)
        layer.register_parameter("w2_qweight", w2_qweight)

        w2_qweight_type = Parameter(
            torch.empty(1, dtype=torch.uint8), requires_grad=False
        )
        set_weight_attrs(
            w2_qweight_type,
            {"is_gguf_weight_type": True, "weight_type": 0, "ignore_warning": True},
        )

        set_weight_attrs(w2_qweight_type, extra_weight_attrs)
        layer.register_parameter("w2_qweight_type", w2_qweight_type)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        return None

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        assert self.fused_experts is None

        if enable_eplb:
            raise NotImplementedError("EPLB not supported for `GGUFMoEMethod` yet.")

        assert activation == "silu", "Only SiLU activation is supported."
        if apply_router_weight_on_input:
            raise NotImplementedError(
                "Apply router weight on input is not supported for "
                "fused GGUF MoE method."
            )

        topk_weights, topk_ids, _ = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
        )
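        # Routing/top-k selection is handled by the shared FusedMoE helper;
        # only the expert GEMMs go through the GGUF-specific fused kernel.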
        return fused_moe_gguf(
            x,
            layer.w13_qweight,
            layer.w2_qweight,
            topk_weights,
            topk_ids,
            layer.w13_qweight_type.weight_type,
            layer.w2_qweight_type.weight_type,
            activation,
        )


class GGUFEmbeddingMethod(GGUFLinearMethod):
    """Embedding method for GGUF.

    Args:
        quant_config: The GGUF quantization config.
    """

    def embedding(self, layer: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
        qweight = layer.qweight
        qweight_type = layer.qweight_type.weight_type
        hidden_size = qweight.tensor_shape[1]

        return apply_gguf_embedding(
            x, qweight, qweight_type, hidden_size, dtype=self.params_dtype
        )


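# Parameter that stays uninitialized until the GGUF shards are loaded;
# `data_container` temporarily holds the raw quantized shard tensors before
# they are merged into a dense padded parameter.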
class GGUFUninitializedParameter(UninitializedParameter):
    cls_to_become = Parameter
    data_container: list[torch.Tensor]