mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 07:34:59 +08:00
538 lines
22 KiB
Python
538 lines
22 KiB
Python
from abc import abstractmethod
|
|
from enum import Enum
|
|
from typing import Callable, List, Optional, Tuple
|
|
|
|
import torch
|
|
|
|
from vllm.distributed import (get_tensor_model_parallel_rank,
|
|
get_tensor_model_parallel_world_size,
|
|
tensor_model_parallel_all_reduce)
|
|
from vllm.logger import init_logger
|
|
from vllm.model_executor.custom_op import CustomOp
|
|
from vllm.model_executor.layers.quantization.base_config import (
|
|
QuantizationConfig, QuantizeMethodBase)
|
|
from vllm.model_executor.utils import set_weight_attrs
|
|
from vllm.platforms import current_platform
|
|
|
|
if current_platform.is_cuda_alike():
|
|
from .fused_moe import fused_experts
|
|
else:
|
|
fused_experts = None # type: ignore
|
|
if current_platform.is_tpu():
|
|
from .moe_pallas import fused_moe as fused_moe_pallas
|
|
else:
|
|
fused_moe_pallas = None # type: ignore
|
|
logger = init_logger(__name__)
|
|
|
|
|
|
class FusedMoeWeightScaleSupported(Enum):
    """Weight-scale granularities understood by the fused MoE weight loader.

    The string values are what ``FusedMoE.weight_loader`` compares against the
    ``quant_method`` attribute of a scale / zero-point parameter.
    """

    # One scale per weight tensor.
    TENSOR = "tensor"
    # One scale per channel.
    CHANNEL = "channel"
    # One scale per group of elements.
    GROUP = "group"
|
|
|
|
|
|
class FusedMoEMethodBase(QuantizeMethodBase):
    """Abstract interface for the weight-creation / compute method of a MoE layer."""

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create the expert weights for ``layer``.

        Subclasses must override this; the base implementation always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
    ) -> torch.Tensor:
        """Run the MoE computation on ``x`` using the weights held by ``layer``.

        Subclasses must override this; the base implementation always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError
|
|
|
|
|
|
@CustomOp.register("unquantized_fused_moe")
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
    """MoE method without quantization."""

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Register uninitialised ``w13_weight`` and ``w2_weight`` on ``layer``.

        Both parameters are allocated with ``torch.empty`` (no gradient) and
        receive ``extra_weight_attrs`` through ``set_weight_attrs``.
        """
        weight_shapes = (
            # Fused gate_up_proj (column parallel)
            ("w13_weight", (num_experts, 2 * intermediate_size, hidden_size)),
            # down_proj (row parallel)
            ("w2_weight", (num_experts, hidden_size, intermediate_size)),
        )
        for weight_name, shape in weight_shapes:
            weight = torch.nn.Parameter(
                torch.empty(*shape, dtype=params_dtype),
                requires_grad=False,
            )
            layer.register_parameter(weight_name, weight)
            set_weight_attrs(weight, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None
    ) -> torch.Tensor:
        """Forward every argument, by keyword, to ``self.forward``."""
        return self.forward(
            x=x,
            layer=layer,
            router_logits=router_logits,
            top_k=top_k,
            renormalize=renormalize,
            use_grouped_topk=use_grouped_topk,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
        )

    def forward_cuda(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None
    ) -> torch.Tensor:
        """Select experts, then run ``fused_experts`` on the layer's weights."""
        routing = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
        )
        topk_weights, topk_ids = routing

        # NOTE: ``inplace=True`` is passed through to the kernel as-is.
        return fused_experts(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
        )

    def forward_cpu(self, *args, **kwargs):
        """Always raise: this method has no CPU implementation."""
        raise NotImplementedError(
            "The CPU backend currently does not support MoE.")

    def forward_tpu(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None
    ) -> torch.Tensor:
        """Run ``fused_moe_pallas``; grouped top-k and custom routing are rejected."""
        # The Pallas path receives the raw router logits and has no hook for
        # grouped or user-supplied routing.
        assert not use_grouped_topk
        assert num_expert_group is None
        assert topk_group is None
        assert custom_routing_function is None
        return fused_moe_pallas(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk=top_k,
            gating_output=router_logits,
            renormalize=renormalize,
        )

    # The native path reuses the CUDA implementation.
    forward_native = forward_cuda
|
|
|
|
|
|
class FusedMoE(torch.nn.Module):
|
|
"""FusedMoE layer for MoE models.
|
|
|
|
This layer contains both MergedColumnParallel weights (gate_up_proj /
|
|
w13) and RowParallelLinear weights (down_proj/ w2).
|
|
|
|
Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
|
|
copy that naming convention here and handle any remapping in the
|
|
load_weights function in each model implementation.
|
|
|
|
Args:
|
|
num_experts: Number of experts in the model
|
|
top_k: Number of experts selected for each token
|
|
hidden_size: Input hidden state size of the transformer
|
|
intermediate_size: Intermediate size of the experts
|
|
params_dtype: Data type for the parameters.
|
|
reduce_results: Whether to all_reduce the output of the layer
|
|
renormalize: Whether to renormalize the logits in the fused_moe kernel
|
|
quant_config: Quantization configuration.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
num_experts: int,
|
|
top_k: int,
|
|
hidden_size: int,
|
|
intermediate_size: int,
|
|
params_dtype: Optional[torch.dtype] = None,
|
|
reduce_results: bool = False,
|
|
renormalize: bool = True,
|
|
use_grouped_topk: bool = False,
|
|
num_expert_group: Optional[int] = None,
|
|
topk_group: Optional[int] = None,
|
|
quant_config: Optional[QuantizationConfig] = None,
|
|
tp_size: Optional[int] = None,
|
|
prefix: str = "",
|
|
custom_routing_function: Optional[Callable] = None,
|
|
):
|
|
super().__init__()
|
|
|
|
if params_dtype is None:
|
|
params_dtype = torch.get_default_dtype()
|
|
|
|
self.tp_size = (tp_size if tp_size is not None else
|
|
get_tensor_model_parallel_world_size())
|
|
self.top_k = top_k
|
|
self.num_experts = num_experts
|
|
self.intermediate_size_per_partition = intermediate_size // self.tp_size
|
|
self.reduce_results = reduce_results
|
|
self.renormalize = renormalize
|
|
self.use_grouped_topk = use_grouped_topk
|
|
if self.use_grouped_topk:
|
|
assert num_expert_group is not None and topk_group is not None
|
|
self.num_expert_group = num_expert_group
|
|
self.topk_group = topk_group
|
|
self.custom_routing_function = custom_routing_function
|
|
|
|
if quant_config is None:
|
|
self.quant_method: Optional[QuantizeMethodBase] = (
|
|
UnquantizedFusedMoEMethod())
|
|
else:
|
|
self.quant_method = quant_config.get_quant_method(self, prefix)
|
|
assert self.quant_method is not None
|
|
|
|
self.quant_method.create_weights(
|
|
layer=self,
|
|
num_experts=num_experts,
|
|
hidden_size=hidden_size,
|
|
intermediate_size=self.intermediate_size_per_partition,
|
|
params_dtype=params_dtype,
|
|
weight_loader=self.weight_loader)
|
|
|
|
def _load_per_tensor_weight_scale(self, shard_id: str,
|
|
param: torch.nn.Parameter,
|
|
loaded_weight: torch.Tensor,
|
|
expert_id: int):
|
|
param_data = param.data
|
|
# for per tensor weight quantization
|
|
if shard_id in ("w1", "w3"):
|
|
# We have to keep the weight scales of w1 and w3 because
|
|
# we need to re-quantize w1/w3 weights after weight loading.
|
|
idx = 0 if shard_id == "w1" else 1
|
|
param_data[expert_id][idx] = loaded_weight
|
|
# If we are in the row parallel case (down_proj)
|
|
elif shard_id == "w2":
|
|
param_data[expert_id] = loaded_weight
|
|
|
|
def _load_model_weight_or_group_weight_scale(self, shard_dim: int,
|
|
expert_data: torch.Tensor,
|
|
shard_id: str,
|
|
loaded_weight: torch.Tensor,
|
|
tp_rank: int):
|
|
# Load grouped weight scales for group quantization
|
|
# or model weights
|
|
if shard_id == "w2":
|
|
self._load_w2(shard_id=shard_id,
|
|
shard_dim=shard_dim,
|
|
loaded_weight=loaded_weight,
|
|
expert_data=expert_data,
|
|
tp_rank=tp_rank)
|
|
elif shard_id in ("w1", "w3"):
|
|
self._load_w13(shard_id=shard_id,
|
|
shard_dim=shard_dim,
|
|
loaded_weight=loaded_weight,
|
|
expert_data=expert_data,
|
|
tp_rank=tp_rank)
|
|
|
|
def _load_per_channel_weight_scale(self, expert_data: torch.Tensor,
|
|
shard_dim: int, shard_id: str,
|
|
loaded_weight: torch.Tensor,
|
|
tp_rank: int):
|
|
# for per channel weight quantization
|
|
if shard_id == "w2":
|
|
expert_data.copy_(loaded_weight)
|
|
elif shard_id in ("w1", "w3"):
|
|
self._load_w13(shard_id=shard_id,
|
|
shard_dim=shard_dim,
|
|
loaded_weight=loaded_weight,
|
|
expert_data=expert_data,
|
|
tp_rank=tp_rank)
|
|
|
|
def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
|
|
shard_id: str, loaded_weight: torch.Tensor, tp_rank: int):
|
|
|
|
# Index the loaded weight for tp sharding.
|
|
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
|
|
shard_size = expert_data.shape[shard_dim] // 2
|
|
loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
|
|
shard_size)
|
|
# Narrow parameter and load.
|
|
# w1, gate_proj: Load into first logical weight of w13.
|
|
if shard_id == "w1":
|
|
expert_data = expert_data.narrow(shard_dim, 0, shard_size)
|
|
# w3, up_proj: Load into second logical weight of w13.
|
|
else:
|
|
assert shard_id == "w3"
|
|
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
|
|
expert_data.copy_(loaded_weight)
|
|
|
|
def _load_w2(self, expert_data: torch.Tensor, shard_dim: int,
|
|
shard_id: str, loaded_weight: torch.Tensor, tp_rank: int):
|
|
|
|
# Index the loaded weight for tp sharding.
|
|
# down_proj: "RowParallel" so tp sharding on input_dim
|
|
# Narrow parameter and load.
|
|
shard_size = expert_data.shape[shard_dim]
|
|
loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
|
|
shard_size)
|
|
# w2, down_proj: Load into only logical weight of w2.
|
|
expert_data.copy_(loaded_weight)
|
|
|
|
def _load_single_value(self, param: torch.nn.Parameter,
|
|
loaded_weight: torch.Tensor, expert_id: int):
|
|
param_data = param.data
|
|
|
|
# Input scales can be loaded directly and should be equal.
|
|
param_data[expert_id] = loaded_weight
|
|
|
|
def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor,
|
|
shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int):
|
|
|
|
if shard_id == "w2":
|
|
self._load_w2(shard_id=shard_id,
|
|
shard_dim=shard_dim,
|
|
loaded_weight=loaded_weight,
|
|
expert_data=expert_data,
|
|
tp_rank=tp_rank)
|
|
else:
|
|
assert shard_id in ("w1", "w3")
|
|
expert_data.copy_(loaded_weight)
|
|
|
|
def weight_loader(self, param: torch.nn.Parameter,
|
|
loaded_weight: torch.Tensor, weight_name: str,
|
|
shard_id: str, expert_id: int) -> None:
|
|
|
|
# compressed-tensors checkpoints with packed weights are stored flipped
|
|
# TODO (mgoin): check self.quant_method.quant_config.quant_format
|
|
# against known CompressionFormat enum values that have this quality
|
|
loaded_weight = loaded_weight.t().contiguous() if (
|
|
self.quant_method.__class__.__name__
|
|
== "CompressedTensorsWNA16MoEMethod") else loaded_weight
|
|
|
|
if shard_id not in ("w1", "w2", "w3"):
|
|
raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
|
|
f"got {shard_id}.")
|
|
|
|
WEIGHT_SCALE_SUPPORTED = [
|
|
e.value for e in FusedMoeWeightScaleSupported
|
|
]
|
|
# Fetch the dim to shard the parameter/loaded weight
|
|
# based on the shard id. This will be whatever
|
|
# dimension intermediate_size is used.
|
|
SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
|
|
|
|
expert_data = param.data[expert_id]
|
|
tp_rank = get_tensor_model_parallel_rank()
|
|
|
|
# is_transposed: if the dim to shard the weight
|
|
# should be flipped. Required by GPTQ, compressed-tensors
|
|
# should be whatever dimension intermediate_size is
|
|
is_transposed = getattr(param, "is_transposed", False)
|
|
shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
|
|
if is_transposed:
|
|
shard_dim = ~shard_dim
|
|
|
|
# Case input scale: input_scale loading is only supported for fp8
|
|
if "input_scale" in weight_name:
|
|
# this is needed for compressed-tensors only
|
|
loaded_weight = loaded_weight.to(param.data.device)
|
|
|
|
if param.data[expert_id] != 1 and (param.data[expert_id] -
|
|
loaded_weight).abs() > 1e-5:
|
|
raise ValueError(
|
|
"input_scales of w1 and w3 of a layer "
|
|
f"must be equal. But got {param.data[expert_id]} "
|
|
f"vs. {loaded_weight}")
|
|
|
|
self._load_single_value(param=param,
|
|
loaded_weight=loaded_weight,
|
|
expert_id=expert_id)
|
|
return
|
|
|
|
# Case g_idx
|
|
if "g_idx" in weight_name:
|
|
self._load_g_idx(shard_dim=0,
|
|
shard_id=shard_id,
|
|
loaded_weight=loaded_weight,
|
|
expert_data=expert_data,
|
|
tp_rank=tp_rank)
|
|
return
|
|
|
|
# Case weight scales and zero_points
|
|
if ("scale" in weight_name or "zero" in weight_name):
|
|
# load the weight scales and zp based on the quantization scheme
|
|
# supported weight scales/zp can be found in
|
|
# FusedMoeWeightScaleSupported
|
|
# TODO @dsikka: once hardened, refactor to use vLLM Parameters
|
|
# specific to each case
|
|
quant_method = getattr(param, "quant_method", None)
|
|
if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
|
|
self._load_per_channel_weight_scale(
|
|
shard_id=shard_id,
|
|
shard_dim=shard_dim,
|
|
loaded_weight=loaded_weight,
|
|
expert_data=expert_data,
|
|
tp_rank=tp_rank)
|
|
elif quant_method == FusedMoeWeightScaleSupported.GROUP.value:
|
|
self._load_model_weight_or_group_weight_scale(
|
|
shard_id=shard_id,
|
|
shard_dim=shard_dim,
|
|
loaded_weight=loaded_weight,
|
|
expert_data=expert_data,
|
|
tp_rank=tp_rank)
|
|
elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
|
|
self._load_per_tensor_weight_scale(shard_id=shard_id,
|
|
param=param,
|
|
loaded_weight=loaded_weight,
|
|
expert_id=expert_id)
|
|
else:
|
|
raise ValueError(
|
|
f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}")
|
|
return
|
|
|
|
# Case weight_shape
|
|
if "weight_shape" in weight_name:
|
|
# only required by compressed-tensors
|
|
self._load_single_value(param=param,
|
|
loaded_weight=loaded_weight,
|
|
expert_id=expert_id)
|
|
return
|
|
|
|
# Case model weights
|
|
if "weight" in weight_name:
|
|
self._load_model_weight_or_group_weight_scale(
|
|
shard_id=shard_id,
|
|
shard_dim=shard_dim,
|
|
loaded_weight=loaded_weight,
|
|
expert_data=expert_data,
|
|
tp_rank=tp_rank)
|
|
return
|
|
|
|
@staticmethod
def select_experts(hidden_states: torch.Tensor,
                   router_logits: torch.Tensor,
                   top_k: int,
                   use_grouped_topk: bool,
                   renormalize: bool,
                   topk_group: Optional[int] = None,
                   num_expert_group: Optional[int] = None,
                   custom_routing_function: Optional[Callable] = None):
    """Pick the experts for each token and return ``(topk_weights, topk_ids)``.

    Grouped top-k takes precedence when ``use_grouped_topk`` is set. Otherwise
    ``custom_routing_function`` is used if given, else ``fused_topk``.
    """
    from vllm.model_executor.layers.fused_moe.fused_moe import (
        fused_topk, grouped_topk)

    # Arguments every routing function receives.
    routing_kwargs = dict(hidden_states=hidden_states,
                          gating_output=router_logits,
                          topk=top_k,
                          renormalize=renormalize)

    # DeepSeek-V2 uses grouped top-k
    if use_grouped_topk:
        assert topk_group is not None
        assert num_expert_group is not None
        topk_weights, topk_ids = grouped_topk(
            **routing_kwargs,
            num_expert_group=num_expert_group,
            topk_group=topk_group)
    else:
        route = (fused_topk if custom_routing_function is None else
                 custom_routing_function)
        topk_weights, topk_ids = route(**routing_kwargs)

    return topk_weights, topk_ids
|
|
|
|
def forward(self, hidden_states: torch.Tensor,
|
|
router_logits: torch.Tensor):
|
|
assert self.quant_method is not None
|
|
|
|
# Matrix multiply.
|
|
final_hidden_states = self.quant_method.apply(
|
|
layer=self,
|
|
x=hidden_states,
|
|
router_logits=router_logits,
|
|
top_k=self.top_k,
|
|
renormalize=self.renormalize,
|
|
use_grouped_topk=self.use_grouped_topk,
|
|
topk_group=self.topk_group,
|
|
num_expert_group=self.num_expert_group,
|
|
custom_routing_function=self.custom_routing_function)
|
|
|
|
if self.reduce_results and self.tp_size > 1:
|
|
final_hidden_states = tensor_model_parallel_all_reduce(
|
|
final_hidden_states)
|
|
|
|
return final_hidden_states
|
|
|
|
@classmethod
|
|
def make_expert_params_mapping(
|
|
cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
|
|
ckpt_up_proj_name: str,
|
|
num_experts: int) -> List[Tuple[str, str, int, str]]:
|
|
|
|
return [
|
|
# (param_name, weight_name, expert_id, shard_id)
|
|
("experts.w13_" if weight_name
|
|
in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_",
|
|
f"experts.{expert_id}.{weight_name}.", expert_id, shard_id)
|
|
for expert_id in range(num_experts) for shard_id, weight_name in [
|
|
("w1", ckpt_gate_proj_name),
|
|
("w2", ckpt_down_proj_name),
|
|
("w3", ckpt_up_proj_name),
|
|
]
|
|
]
|
|
|
|
def _load_fp8_scale(self, param: torch.nn.Parameter,
|
|
loaded_weight: torch.Tensor, weight_name: str,
|
|
shard_id: str, expert_id: int) -> None:
|
|
param_data = param.data
|
|
|
|
# Input scales can be loaded directly and should be equal.
|
|
if "input_scale" in weight_name:
|
|
if param_data[expert_id] != 1 and (param_data[expert_id] -
|
|
loaded_weight).abs() > 1e-5:
|
|
raise ValueError(
|
|
"input_scales of w1 and w3 of a layer "
|
|
f"must be equal. But got {param_data[expert_id]} "
|
|
f"vs. {loaded_weight}")
|
|
param_data[expert_id] = loaded_weight
|
|
# Weight scales
|
|
elif "weight_scale" in weight_name:
|
|
# If we are in merged column case (gate_up_proj)
|
|
if shard_id in ("w1", "w3"):
|
|
# We have to keep the weight scales of w1 and w3 because
|
|
# we need to re-quantize w1/w3 weights after weight loading.
|
|
idx = 0 if shard_id == "w1" else 1
|
|
param_data[expert_id][idx] = loaded_weight
|
|
# If we are in the row parallel case (down_proj)
|
|
else:
|
|
param_data[expert_id] = loaded_weight
|