mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-08 09:03:36 +08:00
[Misc] Support FP8 MoE for compressed-tensors (#8588)
This commit is contained in:
parent
64840dfae4
commit
873edda6cf
@ -1,4 +1,5 @@
|
|||||||
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
|
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
|
||||||
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
|
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
|
||||||
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main
|
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main
|
||||||
|
compressed-tensors, mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8, main
|
||||||
gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
|
gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
|
||||||
|
|||||||
@ -323,10 +323,12 @@ class FusedMoE(torch.nn.Module):
|
|||||||
loaded_weight: torch.Tensor, weight_name: str,
|
loaded_weight: torch.Tensor, weight_name: str,
|
||||||
shard_id: str, expert_id: int) -> None:
|
shard_id: str, expert_id: int) -> None:
|
||||||
|
|
||||||
# compressed-tensors represents weights on disk which are flipped
|
# compressed-tensors checkpoints with packed weights are stored flipped
|
||||||
|
# TODO (mgoin): check self.quant_method.quant_config.quant_format
|
||||||
|
# against known CompressionFormat enum values that have this quality
|
||||||
loaded_weight = loaded_weight.t().contiguous() if (
|
loaded_weight = loaded_weight.t().contiguous() if (
|
||||||
self.quant_method.__class__.__name__
|
self.quant_method.__class__.__name__
|
||||||
== "CompressedTensorsMoEMethod") else loaded_weight
|
== "CompressedTensorsWNA16MoEMethod") else loaded_weight
|
||||||
|
|
||||||
if shard_id not in ("w1", "w2", "w3"):
|
if shard_id not in ("w1", "w2", "w3"):
|
||||||
raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
|
raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
|
||||||
@ -353,6 +355,9 @@ class FusedMoE(torch.nn.Module):
|
|||||||
|
|
||||||
# Case input scale: input_scale loading is only supported for fp8
|
# Case input scale: input_scale loading is only supported for fp8
|
||||||
if "input_scale" in weight_name:
|
if "input_scale" in weight_name:
|
||||||
|
# this is needed for compressed-tensors only
|
||||||
|
loaded_weight = loaded_weight.to(param.data.device)
|
||||||
|
|
||||||
if param.data[expert_id] != 1 and (param.data[expert_id] -
|
if param.data[expert_id] != 1 and (param.data[expert_id] -
|
||||||
loaded_weight).abs() > 1e-5:
|
loaded_weight).abs() > 1e-5:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@ -73,7 +73,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
|||||||
if isinstance(layer, Attention):
|
if isinstance(layer, Attention):
|
||||||
return CompressedTensorsKVCacheMethod(self)
|
return CompressedTensorsKVCacheMethod(self)
|
||||||
if isinstance(layer, FusedMoE):
|
if isinstance(layer, FusedMoE):
|
||||||
return CompressedTensorsMoEMethod(self)
|
return CompressedTensorsMoEMethod.get_moe_method(self)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@ -5,12 +5,16 @@ from typing import Callable, List, Optional
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
|
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
|
||||||
|
FusedMoeWeightScaleSupported)
|
||||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||||
WNA16_SUPPORTED_BITS)
|
WNA16_SUPPORTED_BITS)
|
||||||
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||||
CompressionFormat)
|
CompressionFormat, QuantizationStrategy)
|
||||||
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
|
all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
|
from vllm.utils import is_hip, print_warning_once
|
||||||
|
|
||||||
|
|
||||||
class GPTQMarlinState(Enum):
|
class GPTQMarlinState(Enum):
|
||||||
@ -18,11 +22,219 @@ class GPTQMarlinState(Enum):
|
|||||||
READY = enum.auto()
|
READY = enum.auto()
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["CompressedTensorsMoEMethod"]
|
__all__ = [
|
||||||
|
"CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod",
|
||||||
|
"CompressedTensorsWNA16MoEMethod"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_moe_method(
|
||||||
|
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
|
||||||
|
) -> "CompressedTensorsMoEMethod":
|
||||||
|
# TODO: @dsikka: refactor this to use schemes as other kernels
|
||||||
|
# are supported + check if the layer is being ignored.
|
||||||
|
weight_quant = quant_config.target_scheme_map["Linear"].get("weights")
|
||||||
|
input_quant = quant_config.target_scheme_map["Linear"].get(
|
||||||
|
"input_activations")
|
||||||
|
|
||||||
|
if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
|
||||||
|
return CompressedTensorsWNA16MoEMethod(quant_config)
|
||||||
|
elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
|
||||||
|
return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
|
||||||
|
|
||||||
|
|
||||||
|
class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
|
||||||
|
):
|
||||||
|
self.quant_config = quant_config
|
||||||
|
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
|
||||||
|
"weights")
|
||||||
|
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
|
||||||
|
"input_activations")
|
||||||
|
|
||||||
|
if not (self.weight_quant.strategy == QuantizationStrategy.TENSOR
|
||||||
|
and self.input_quant.strategy == QuantizationStrategy.TENSOR):
|
||||||
|
raise ValueError(
|
||||||
|
"For FP8 Fused MoE layers, only per-tensor scales"
|
||||||
|
"for weights and activations are supported. Found "
|
||||||
|
f"{self.weight_quant}, {self.input_quant}")
|
||||||
|
|
||||||
|
self.static_input_scales = not self.input_quant.dynamic
|
||||||
|
|
||||||
|
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||||
|
hidden_size: int, intermediate_size: int,
|
||||||
|
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||||
|
|
||||||
|
params_dtype = torch.float8_e4m3fn
|
||||||
|
|
||||||
|
# WEIGHTS
|
||||||
|
w13_weight = torch.nn.Parameter(torch.empty(num_experts,
|
||||||
|
2 * intermediate_size,
|
||||||
|
hidden_size,
|
||||||
|
dtype=params_dtype),
|
||||||
|
requires_grad=False)
|
||||||
|
layer.register_parameter("w13_weight", w13_weight)
|
||||||
|
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||||
|
|
||||||
|
w2_weight = torch.nn.Parameter(torch.empty(num_experts,
|
||||||
|
hidden_size,
|
||||||
|
intermediate_size,
|
||||||
|
dtype=params_dtype),
|
||||||
|
requires_grad=False)
|
||||||
|
layer.register_parameter("w2_weight", w2_weight)
|
||||||
|
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||||
|
|
||||||
|
# WEIGHT_SCALES
|
||||||
|
# Allocate 2 scales for w1 and w3 respectively.
|
||||||
|
# They will be combined to a single scale after weight loading.
|
||||||
|
w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
|
||||||
|
2,
|
||||||
|
dtype=torch.float32),
|
||||||
|
requires_grad=False)
|
||||||
|
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||||
|
|
||||||
|
w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
|
||||||
|
dtype=torch.float32),
|
||||||
|
requires_grad=False)
|
||||||
|
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||||
|
# Add the quantization method used (per tensor/grouped/channel)
|
||||||
|
# to ensure the weight scales are loaded in properly
|
||||||
|
extra_weight_attrs.update(
|
||||||
|
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
|
||||||
|
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||||
|
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||||
|
|
||||||
|
# INPUT_SCALES
|
||||||
|
if self.static_input_scales:
|
||||||
|
w13_input_scale = torch.nn.Parameter(torch.ones(
|
||||||
|
num_experts, dtype=torch.float32),
|
||||||
|
requires_grad=False)
|
||||||
|
layer.register_parameter("w13_input_scale", w13_input_scale)
|
||||||
|
set_weight_attrs(w13_input_scale, extra_weight_attrs)
|
||||||
|
|
||||||
|
w2_input_scale = torch.nn.Parameter(torch.ones(
|
||||||
|
num_experts, dtype=torch.float32),
|
||||||
|
requires_grad=False)
|
||||||
|
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||||
|
set_weight_attrs(w2_input_scale, extra_weight_attrs)
|
||||||
|
else:
|
||||||
|
layer.w13_input_scale = None
|
||||||
|
layer.w2_input_scale = None
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
# Fp8 moe kernels require a single activation scale.
|
||||||
|
# We take the max of all the scales in case they differ.
|
||||||
|
if self.static_input_scales:
|
||||||
|
if (layer.w13_input_scale is None or layer.w2_input_scale is None):
|
||||||
|
raise ValueError(
|
||||||
|
"QuantConfig has static quantization, but found "
|
||||||
|
"activation scales are None.")
|
||||||
|
if (not all_close_1d(layer.w13_input_scale)
|
||||||
|
or not all_close_1d(layer.w2_input_scale)):
|
||||||
|
print_warning_once(
|
||||||
|
"Found input_scales that are not equal for "
|
||||||
|
"fp8 MoE layer. Using the maximum across experts "
|
||||||
|
"for each layer. ")
|
||||||
|
layer.w13_input_scale = torch.nn.Parameter(
|
||||||
|
layer.w13_input_scale.max(), requires_grad=False)
|
||||||
|
layer.w2_input_scale = torch.nn.Parameter(
|
||||||
|
layer.w2_input_scale.max(), requires_grad=False)
|
||||||
|
|
||||||
|
# If rocm, normalize the weights and scales to e4m3fnuz
|
||||||
|
if is_hip():
|
||||||
|
# Normalize the weights and scales
|
||||||
|
w13_weight, w13_weight_scale, w13_input_scale = \
|
||||||
|
normalize_e4m3fn_to_e4m3fnuz(
|
||||||
|
layer.w13_weight, layer.w13_weight_scale,
|
||||||
|
layer.w13_input_scale)
|
||||||
|
w2_weight, w2_weight_scale, w2_input_scale = \
|
||||||
|
normalize_e4m3fn_to_e4m3fnuz(
|
||||||
|
layer.w2_weight, layer.w2_weight_scale,
|
||||||
|
layer.w2_input_scale)
|
||||||
|
# Reset the parameter
|
||||||
|
layer.w13_weight = torch.nn.Parameter(w13_weight,
|
||||||
|
requires_grad=False)
|
||||||
|
layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale,
|
||||||
|
requires_grad=False)
|
||||||
|
if w13_input_scale is not None:
|
||||||
|
layer.w13_input_scale = torch.nn.Parameter(w13_input_scale,
|
||||||
|
requires_grad=False)
|
||||||
|
layer.w2_weight = torch.nn.Parameter(w2_weight,
|
||||||
|
requires_grad=False)
|
||||||
|
layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
|
||||||
|
requires_grad=False)
|
||||||
|
if w2_input_scale is not None:
|
||||||
|
layer.w2_input_scale = torch.nn.Parameter(w2_input_scale,
|
||||||
|
requires_grad=False)
|
||||||
|
|
||||||
|
# Fp8 moe kernel needs single weight scale for w13 per expert.
|
||||||
|
# We take the max then dequant and requant each expert.
|
||||||
|
assert layer.w13_weight_scale is not None
|
||||||
|
shard_size = layer.intermediate_size_per_partition
|
||||||
|
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
|
||||||
|
for expert_id in range(layer.num_experts):
|
||||||
|
start = 0
|
||||||
|
for shard_id in range(2):
|
||||||
|
dq_weight = per_tensor_dequantize(
|
||||||
|
layer.w13_weight[expert_id][start:start + shard_size, :],
|
||||||
|
layer.w13_weight_scale[expert_id][shard_id])
|
||||||
|
layer.w13_weight[expert_id][
|
||||||
|
start:start + shard_size, :], _ = ops.scaled_fp8_quant(
|
||||||
|
dq_weight, max_w13_scales[expert_id])
|
||||||
|
start += shard_size
|
||||||
|
|
||||||
|
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
|
||||||
|
requires_grad=False)
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
top_k: int,
|
||||||
|
renormalize: bool = True,
|
||||||
|
use_grouped_topk: bool = False,
|
||||||
|
num_expert_group: Optional[int] = None,
|
||||||
|
topk_group: Optional[int] = None,
|
||||||
|
custom_routing_function: Optional[Callable] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||||
|
|
||||||
|
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||||
|
hidden_states=x,
|
||||||
|
router_logits=router_logits,
|
||||||
|
use_grouped_topk=use_grouped_topk,
|
||||||
|
top_k=top_k,
|
||||||
|
renormalize=renormalize,
|
||||||
|
topk_group=topk_group,
|
||||||
|
num_expert_group=num_expert_group,
|
||||||
|
custom_routing_function=custom_routing_function)
|
||||||
|
|
||||||
|
return fused_experts(x,
|
||||||
|
layer.w13_weight,
|
||||||
|
layer.w2_weight,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
inplace=True,
|
||||||
|
use_fp8_w8a8=True,
|
||||||
|
w1_scale=layer.w13_weight_scale,
|
||||||
|
w2_scale=layer.w2_weight_scale,
|
||||||
|
a1_scale=layer.w13_input_scale,
|
||||||
|
a2_scale=layer.w2_input_scale)
|
||||||
|
|
||||||
|
|
||||||
|
class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
|
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
|
||||||
|
|||||||
@ -321,13 +321,13 @@ class PhiMoEAttention(nn.Module):
|
|||||||
self.total_num_heads,
|
self.total_num_heads,
|
||||||
self.total_num_kv_heads,
|
self.total_num_kv_heads,
|
||||||
bias=True,
|
bias=True,
|
||||||
quant_config=None,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.o_proj = RowParallelLinear(
|
self.o_proj = RowParallelLinear(
|
||||||
self.total_num_heads * self.head_dim,
|
self.total_num_heads * self.head_dim,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
bias=True,
|
bias=True,
|
||||||
quant_config=None,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.rotary_emb = get_rope(
|
self.rotary_emb = get_rope(
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user