[Bugfix] Fix FP8 Marlin MoE and enable for compressed-tensors models (#18026)
Signed-off-by: mgoin <mgoin64@gmail.com>
parent 6266c57bae
commit 9a2a6357de
@@ -9,6 +9,7 @@ from compressed_tensors import CompressionFormat
 from compressed_tensors.quantization import (ActivationOrdering,
                                              QuantizationStrategy)

+import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe  # noqa
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
@@ -20,10 +21,13 @@ from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     check_moe_marlin_supports_layer, marlin_make_workspace_new,
     marlin_moe_permute_scales)
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
+    prepare_moe_fp8_layer_for_marlin)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
+from vllm.scalar_type import scalar_types

 logger = init_logger(__name__)

@@ -114,10 +118,24 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
                 "For FP8 Fused MoE layer, we require either per tensor or "
                 "channelwise, dynamic per token quantization.")

+        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
+        # kernel for fast weight-only FP8 quantization
+        self.use_marlin = (not current_platform.has_device_capability(89)
+                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
+        # Disable marlin for rocm
+        if current_platform.is_rocm():
+            self.use_marlin = False
+
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):

+        layer.intermediate_size_per_partition = intermediate_size_per_partition
+        layer.hidden_size = hidden_size
+        layer.num_experts = num_experts
+        layer.orig_dtype = params_dtype
+        layer.weight_block_size = None
+
         params_dtype = torch.float8_e4m3fn

         # WEIGHTS
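The hunk above keys the new FP8 Marlin fallback off the device's FP8 support and a test override, and disables it on ROCm. Below is a minimal sketch of that same decision, assuming current_platform and envs are imported as in the diff; the helper name pick_fp8_moe_backend is hypothetical and only illustrates the logic, it is not part of vLLM.

    import vllm.envs as envs
    from vllm.platforms import current_platform


    def pick_fp8_moe_backend() -> str:
        """Illustrative only: mirrors the use_marlin decision in the diff."""
        # GPUs with native FP8 (compute capability >= 8.9) keep the FP8 kernels;
        # older GPUs fall back to the weight-only FP8 Marlin kernel.
        use_marlin = (not current_platform.has_device_capability(89)
                      or envs.VLLM_TEST_FORCE_FP8_MARLIN)
        # Marlin is not available on ROCm, so the fallback is disabled there.
        if current_platform.is_rocm():
            use_marlin = False
        return "marlin" if use_marlin else "fp8"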
@@ -280,6 +298,12 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             from vllm.model_executor.layers.fused_moe import fused_experts
             self.fused_experts_func = fused_experts

+        if self.use_marlin:
+            prepare_moe_fp8_layer_for_marlin(layer, False)
+            # Activations not quantized for marlin.
+            del layer.w13_input_scale
+            del layer.w2_input_scale
+
     def apply(
         self,
         layer: torch.nn.Module,
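Since the Marlin path is taken automatically only on pre-FP8 hardware, exercising it on an FP8-capable GPU relies on the VLLM_TEST_FORCE_FP8_MARLIN flag visible in the diff. A hedged usage sketch follows; setting the variable in the process environment before constructing the engine is an assumption about when vLLM reads it, and the model id is a placeholder, not a real checkpoint.

    import os

    # Assumption: setting the flag before the engine is built is sufficient
    # for vllm.envs to pick it up.
    os.environ["VLLM_TEST_FORCE_FP8_MARLIN"] = "1"

    from vllm import LLM  # noqa: E402

    # Placeholder model id; any FP8 compressed-tensors MoE checkpoint would do.
    llm = LLM(model="some-org/some-fp8-moe-model")
    print(llm.generate("Hello")[0].outputs[0].text)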
@@ -311,6 +335,24 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias)

+        if self.use_marlin:
+            assert activation == "silu", (
+                f"{activation} not supported for Marlin MoE.")
+            assert not apply_router_weight_on_input, (
+                "Apply router weight on input not supported for Marlin MoE.")
+            return torch.ops.vllm.fused_marlin_moe(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                layer.w13_weight_scale,
+                layer.w2_weight_scale,
+                router_logits,
+                topk_weights,
+                topk_ids,
+                quant_type_id=scalar_types.float8_e4m3fn.id,
+                global_num_experts=global_num_experts,
+                expert_map=expert_map)
+
         return self.fused_experts_func(
             hidden_states=x,
             w1=layer.w13_weight,
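The apply() change is a two-way dispatch: assert the Marlin-only restrictions, call the fused Marlin op, otherwise fall through to the generic fused_experts path. A condensed, illustrative sketch of that control flow is below; fp8_moe_forward and its simplified argument list are hypothetical, and the real fallback call also passes FP8 flags and scales that are omitted here.

    import torch

    from vllm.scalar_type import scalar_types


    def fp8_moe_forward(layer, x, router_logits, topk_weights, topk_ids,
                        use_marlin, fused_experts_func, activation="silu",
                        apply_router_weight_on_input=False,
                        global_num_experts=-1, expert_map=None):
        """Illustrative only: condensed version of the dispatch in apply()."""
        if use_marlin:
            # The Marlin kernel only handles SiLU and never applies router
            # weights on the input, hence the up-front asserts in the diff.
            assert activation == "silu"
            assert not apply_router_weight_on_input
            return torch.ops.vllm.fused_marlin_moe(
                x, layer.w13_weight, layer.w2_weight,
                layer.w13_weight_scale, layer.w2_weight_scale,
                router_logits, topk_weights, topk_ids,
                quant_type_id=scalar_types.float8_e4m3fn.id,
                global_num_experts=global_num_experts, expert_map=expert_map)
        # Generic fused-experts path (the real call also forwards FP8 scales).
        return fused_experts_func(hidden_states=x, w1=layer.w13_weight,
                                  w2=layer.w2_weight,
                                  topk_weights=topk_weights, topk_ids=topk_ids)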
@@ -517,7 +559,8 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
         activation: str = "silu",
     ) -> torch.Tensor:

-        assert activation == "silu"
+        assert activation == "silu", (
+            f"{activation} not supported for Cutlass MoE.")

         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
@@ -942,11 +985,10 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
     ) -> torch.Tensor:
-        assert activation == "silu", "Only SiLU activation is supported."
-        if apply_router_weight_on_input:
-            raise NotImplementedError(
-                "Apply router weight on input is not supported for "
-                "fused Marlin MoE method.")
+        assert activation == "silu", (
+            f"{activation} not supported for Marlin MoE.")
+        assert not apply_router_weight_on_input, (
+            "Apply router weight on input not supported for Marlin MoE.")

         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
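This hunk swaps an explicit if/raise NotImplementedError for an assert carrying the same standardized message used by the other MoE methods. One nuance of that trade-off: asserts are stripped when Python runs with -O, while the old raise always fired. The tiny comparison below illustrates the difference; both guard functions are hypothetical and exist only for this note.

    def guard_with_assert(apply_router_weight_on_input: bool) -> None:
        # Removed entirely under `python -O` (optimized mode).
        assert not apply_router_weight_on_input, (
            "Apply router weight on input not supported for Marlin MoE.")


    def guard_with_raise(apply_router_weight_on_input: bool) -> None:
        # Always enforced, regardless of interpreter flags.
        if apply_router_weight_on_input:
            raise NotImplementedError(
                "Apply router weight on input not supported for Marlin MoE.")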
@@ -811,6 +811,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         )

         if self.use_marlin:
+            assert activation == "silu", (
+                f"{activation} not supported for Marlin MoE.")
+            assert not apply_router_weight_on_input, (
+                "Apply router weight on input not supported for Marlin MoE.")
             return torch.ops.vllm.fused_marlin_moe(
                 x,
                 layer.w13_weight,
@@ -268,6 +268,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
             tensor_list.append(marlin_scales)

         scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
+        scales = fp8_fused_exponent_bias_into_scales(scales)
         scales = torch.nn.Parameter(scales, requires_grad=False)

         setattr(layer, name + "_weight_scale", scales)
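The marlin_utils_fp8 hunk adds one extra transform over the per-expert scales after they are stacked into a single tensor with a leading expert dimension. A toy illustration of that stacking step on dummy data follows; the two-expert setup and shapes are made up, and fp8_fused_exponent_bias_into_scales itself is not re-implemented here.

    import torch

    # Pretend per-expert Marlin scale tensors, one per expert in an MoE layer.
    num_experts, out_features = 2, 8
    tensor_list = [torch.rand(1, out_features, dtype=torch.float16)
                   for _ in range(num_experts)]

    # Same stacking as in the diff: prepend an expert dimension and concatenate,
    # yielding a single (num_experts, 1, out_features) scale tensor.
    scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
    assert scales.shape == (num_experts, 1, out_features)

    # In the real code, fp8_fused_exponent_bias_into_scales(scales) runs here
    # before the result is wrapped in a non-trainable Parameter.
    scales = torch.nn.Parameter(scales, requires_grad=False)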