[Bugfix] Fix FP8 Marlin MoE and enable for compressed-tensors models (#18026)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
6266c57bae
commit
9a2a6357de
@@ -9,6 +9,7 @@ from compressed_tensors import CompressionFormat
from compressed_tensors.quantization import (ActivationOrdering,
                                              QuantizationStrategy)

import vllm.envs as envs
import vllm.model_executor.layers.fused_moe  # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
@@ -20,10 +21,13 @@ from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    check_moe_marlin_supports_layer, marlin_make_workspace_new,
    marlin_moe_permute_scales)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    prepare_moe_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

logger = init_logger(__name__)
@@ -114,10 +118,24 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
                "For FP8 Fused MoE layer, we require either per tensor or "
                "channelwise, dynamic per token quantization.")

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (not current_platform.has_device_capability(89)
                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):

        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        params_dtype = torch.float8_e4m3fn

        # WEIGHTS
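Context (not part of the diff): SM 8.9 (Ada) is the first NVIDIA architecture with native FP8 tensor-core support, so anything older falls back to the weight-only Marlin kernel, and VLLM_TEST_FORCE_FP8_MARLIN lets tests force that path on newer GPUs. A rough sketch of the same decision written against plain torch/os APIs; should_use_fp8_marlin is a hypothetical helper, not vLLM code, and it omits the ROCm carve-out handled above.

import os
import torch

def should_use_fp8_marlin() -> bool:
    # Hypothetical re-statement of the capability check above, outside of vLLM.
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()
    lacks_native_fp8 = (major, minor) < (8, 9)   # pre-Ada GPUs have no FP8 tensor cores
    forced = os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0") == "1"
    return lacks_native_fp8 or forced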
@@ -280,6 +298,12 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
            from vllm.model_executor.layers.fused_moe import fused_experts
            self.fused_experts_func = fused_experts

        if self.use_marlin:
            prepare_moe_fp8_layer_for_marlin(layer, False)
            # Activations not quantized for marlin.
            del layer.w13_input_scale
            del layer.w2_input_scale

    def apply(
        self,
        layer: torch.nn.Module,
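Context (not part of the diff): the input scales can be dropped here because the Marlin path is weight-only FP8, so activations stay in the model dtype. A minimal conceptual sketch of what "weight-only FP8" means; the helper below is illustrative, not vLLM's kernel, which fuses the dequantization into the GEMM.

import torch

def weight_only_fp8_linear(x: torch.Tensor, w_fp8: torch.Tensor,
                           w_scale: torch.Tensor) -> torch.Tensor:
    # Illustrative only: dequantize the FP8 weight with its scale and run a
    # normal matmul in the activation dtype; no activation quantization needed.
    w = w_fp8.to(x.dtype) * w_scale
    return x @ w.t()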
@@ -311,6 +335,24 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias)

        if self.use_marlin:
            assert activation == "silu", (
                f"{activation} not supported for Marlin MoE.")
            assert not apply_router_weight_on_input, (
                "Apply router weight on input not supported for Marlin MoE.")
            return torch.ops.vllm.fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                quant_type_id=scalar_types.float8_e4m3fn.id,
                global_num_experts=global_num_experts,
                expert_map=expert_map)

        return self.fused_experts_func(
            hidden_states=x,
            w1=layer.w13_weight,
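Context (not part of the diff): both branches compute the same top-k expert mixture; fused_marlin_moe just does it with FP8 Marlin GEMMs. A naive reference of the computation, under the usual assumption that w13 holds the concatenated gate/up projections and the activation is SiLU-and-mul; shapes and names are illustrative, not vLLM's layout contract.

import torch
import torch.nn.functional as F

def naive_moe(x: torch.Tensor,              # [num_tokens, hidden]
              w13: torch.Tensor,            # [num_experts, 2 * intermediate, hidden]
              w2: torch.Tensor,             # [num_experts, hidden, intermediate]
              topk_weights: torch.Tensor,   # [num_tokens, top_k]
              topk_ids: torch.Tensor        # [num_tokens, top_k]
              ) -> torch.Tensor:
    out = torch.zeros_like(x)
    num_tokens, top_k = topk_ids.shape
    inter = w2.shape[-1]
    for t in range(num_tokens):
        for k in range(top_k):
            e = int(topk_ids[t, k])
            gate_up = x[t] @ w13[e].t()           # [2 * intermediate]
            gate, up = gate_up[:inter], gate_up[inter:]
            h = F.silu(gate) * up                 # SiLU-and-mul
            out[t] += topk_weights[t, k] * (h @ w2[e].t())
    return out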
@@ -517,7 +559,8 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
        activation: str = "silu",
    ) -> torch.Tensor:

        assert activation == "silu"
        assert activation == "silu", (
            f"{activation} not supported for Cutlass MoE.")

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
@@ -942,11 +985,10 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
    ) -> torch.Tensor:
        assert activation == "silu", "Only SiLU activation is supported."
        if apply_router_weight_on_input:
            raise NotImplementedError(
                "Apply router weight on input is not supported for "
                "fused Marlin MoE method.")
        assert activation == "silu", (
            f"{activation} not supported for Marlin MoE.")
        assert not apply_router_weight_on_input, (
            "Apply router weight on input not supported for Marlin MoE.")

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
@@ -811,6 +811,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        )

        if self.use_marlin:
            assert activation == "silu", (
                f"{activation} not supported for Marlin MoE.")
            assert not apply_router_weight_on_input, (
                "Apply router weight on input not supported for Marlin MoE.")
            return torch.ops.vllm.fused_marlin_moe(
                x,
                layer.w13_weight,
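A hypothetical way to exercise the Marlin FP8 MoE path end to end, using the VLLM_TEST_FORCE_FP8_MARLIN flag referenced in this diff; the checkpoint name below is only a placeholder for some FP8 compressed-tensors MoE model, not a specific recommendation.

import os

# Set before vLLM is imported so the quantization method picks it up.
os.environ["VLLM_TEST_FORCE_FP8_MARLIN"] = "1"

from vllm import LLM

# Placeholder model id: substitute any FP8 compressed-tensors MoE checkpoint.
llm = LLM(model="some-org/some-fp8-moe-model")
print(llm.generate("Hello, my name is")[0].outputs[0].text)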
@@ -268,6 +268,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
        tensor_list.append(marlin_scales)

    scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
    scales = fp8_fused_exponent_bias_into_scales(scales)
    scales = torch.nn.Parameter(scales, requires_grad=False)

    setattr(layer, name + "_weight_scale", scales)
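Aside (not part of the diff): fp8_fused_exponent_bias_into_scales is the one-line addition in this hunk; judging by its name, it folds the FP8 exponent-bias correction into the per-expert scales up front so the kernel does not have to apply it per element. The surrounding cat-of-unsqueezed-tensors is behaviorally the same as torch.stack, for example:

import torch

tensor_list = [torch.randn(4, 8) for _ in range(3)]  # stand-in per-expert scales
a = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
b = torch.stack(tensor_list, dim=0)
assert torch.equal(a, b)  # both produce a [num_experts, ...] tensor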