[Refactor] Merge Compressed Tensor FP8 CompressedTensorsW8A8Fp8MoEMethod and CompressedTensorsW8A8Fp8MoECutlassMethod (#21775)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye 2025-07-28 21:47:21 -04:00 committed by GitHub
parent 947e982ede
commit 48b763d6b5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -45,7 +45,6 @@ class GPTQMarlinState(Enum):
__all__ = [
"CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod",
"CompressedTensorsW8A8Fp8MoECutlassMethod",
"CompressedTensorsW8A8Int8MoEMethod",
"CompressedTensorsWNA16MarlinMoEMethod", "CompressedTensorsWNA16MoEMethod",
"CompressedTensorsW4A4MoeMethod"
@ -84,9 +83,8 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
return CompressedTensorsW4A4MoeMethod()
elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)):
return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)
or quant_config._is_fp8_w8a8(weight_quant, input_quant)):
return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8MoEMethod(quant_config)
@ -378,6 +376,14 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
# cutlass path
self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100(
self.weight_quant, self.input_quant)
self.use_cutlass = (quant_config._is_fp8_w8a8_sm90(
self.weight_quant, self.input_quant) or self.is_fp8_w8a8_sm100)
self.fused_experts = None # type: ignore[assignment]
self.disable_expert_map = False
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
@ -558,6 +564,34 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
prepare_finalize: FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
) -> FusedMoEPermuteExpertsUnpermute:
# cutlass path
if self.use_cutlass:
from vllm.model_executor.layers.fused_moe import CutlassExpertsFp8
use_batched_format = (prepare_finalize.activation_format ==
FusedMoEActivationFormat.BatchedExperts)
num_dispatchers = prepare_finalize.num_dispatchers()
num_experts = (moe.num_local_experts
if use_batched_format else moe.num_experts)
logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
experts = CutlassExpertsFp8(
num_experts,
moe.in_dtype,
self.input_quant.strategy == QuantizationStrategy.TOKEN,
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
num_dispatchers=num_dispatchers,
use_batched_format=use_batched_format,
)
self.disable_expert_map = (num_dispatchers > 1
or not experts.supports_expert_map())
return experts
# triton path
from vllm.model_executor.layers.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts)
@ -629,6 +663,68 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
indices_type=self.topk_indices_dtype,
)
# cutlass path
if self.use_cutlass:
per_act_token = (
self.input_quant.strategy == QuantizationStrategy.TOKEN)
per_channel_quant = (
self.weight_quant.strategy == QuantizationStrategy.CHANNEL)
# small-batch fallback on SM100
if self.is_fp8_w8a8_sm100 and topk_ids.shape[0] <= 8:
from vllm.model_executor.layers.fused_moe import fused_experts
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True,
per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts,
expert_map=None if self.disable_expert_map else expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale)
if self.fused_experts is None:
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp8)
return cutlass_moe_fp8(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
per_act_token=per_act_token,
activation=activation,
global_num_experts=global_num_experts,
expert_map=None if self.disable_expert_map else expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)
else:
return self.fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
activation=activation,
global_num_experts=global_num_experts,
expert_map=None if self.disable_expert_map else expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)
if self.rocm_aiter_moe_enabled:
return self.rocm_aiter_fused_experts_func(
hidden_states=x,
@ -685,291 +781,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
a2_scale=layer.w2_input_scale)
class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
def __init__(
self,
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
):
self.quant_config = quant_config
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
"weights")
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
"input_activations")
per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR
and self.input_quant.strategy
== QuantizationStrategy.TENSOR)
per_channel = (
self.weight_quant.strategy == QuantizationStrategy.CHANNEL
and self.input_quant.strategy == QuantizationStrategy.TOKEN)
if not (per_tensor or per_channel):
raise ValueError(
"For FP8 Fused MoE layers, we require per tensor "
"or channelwise, dynamic per token quantization. Found "
f"{self.weight_quant}, {self.input_quant}")
self.static_input_scales = not self.input_quant.dynamic
if self.static_input_scales and per_channel:
raise ValueError(
"For FP8 Fused MoE layer, we require either per tensor or "
"channelwise, dynamic per token quantization.")
self.topk_indices_dtype = None
self.fused_experts = None # type: ignore
self.disable_expert_map = False
self.is_fp8_w8a8_sm100 = self.quant_config._is_fp8_w8a8_sm100(
self.weight_quant, self.input_quant)
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
params_dtype = torch.float8_e4m3fn
# WEIGHTS
w13_weight = torch.nn.Parameter(torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# WEIGHT_SCALES
if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
# Allocate 2 scales for w1 and w3 respectively.
# They are combined to a single scale after weight loading.
w13_weight_scale = torch.nn.Parameter(torch.ones(
num_experts, 2, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(torch.ones(
num_experts, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add PER-TENSOR quantization for FusedMoE.weight_loader.
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
w13_weight_scale = torch.nn.Parameter(torch.ones(
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(torch.ones(
num_experts, hidden_size, 1, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add PER-CHANNEL quantization for FusedMoE.weight_loader.
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
# INPUT_SCALES
if self.static_input_scales:
w13_input_scale = torch.nn.Parameter(torch.ones(
num_experts, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_input_scale", w13_input_scale)
set_weight_attrs(w13_input_scale, extra_weight_attrs)
w2_input_scale = torch.nn.Parameter(torch.ones(
num_experts, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_input_scale", w2_input_scale)
set_weight_attrs(w2_input_scale, extra_weight_attrs)
else:
layer.w13_input_scale = None
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ.
if self.static_input_scales:
assert self.input_quant.strategy == QuantizationStrategy.TENSOR
if (layer.w13_input_scale is None or layer.w2_input_scale is None):
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None.")
if (not all_close_1d(layer.w13_input_scale)
or not all_close_1d(layer.w2_input_scale)):
logger.warning_once(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer.")
layer.w13_input_scale = torch.nn.Parameter(
layer.w13_input_scale.max(), requires_grad=False)
layer.w2_input_scale = torch.nn.Parameter(
layer.w2_input_scale.max(), requires_grad=False)
# For Per-TENSOR case, Fp8 moe kernel needs single weight scale
# for w13 per expert. Use max then dequant and requant each expert.
if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
assert layer.w13_weight_scale is not None
shard_size = layer.intermediate_size_per_partition
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
for expert_id in range(layer.local_num_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
layer.w13_weight[expert_id][start:start +
shard_size, :],
layer.w13_weight_scale[expert_id][shard_id])
layer.w13_weight[expert_id][
start:start + shard_size, :], _ = ops.scaled_fp8_quant(
dq_weight, max_w13_scales[expert_id])
start += shard_size
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
requires_grad=False)
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
) -> FusedMoEPermuteExpertsUnpermute:
from vllm.model_executor.layers.fused_moe import CutlassExpertsFp8
use_batched_format = (prepare_finalize.activation_format ==
FusedMoEActivationFormat.BatchedExperts)
num_dispatchers = prepare_finalize.num_dispatchers()
num_experts = (moe.num_local_experts
if use_batched_format else moe.num_experts)
logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
experts = CutlassExpertsFp8(
num_experts,
moe.in_dtype,
self.input_quant.strategy == QuantizationStrategy.TOKEN,
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
num_dispatchers=num_dispatchers,
use_batched_format=use_batched_format,
)
self.disable_expert_map = (num_dispatchers > 1
or not experts.supports_expert_map())
return experts
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for "
"`CompressedTensorsW8A8Fp8MoECutlassMethod` yet.")
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
per_act_token = (
self.input_quant.strategy == QuantizationStrategy.TOKEN)
per_channel_quant = (
self.weight_quant.strategy == QuantizationStrategy.CHANNEL)
# Triton fused_experts is faster in small batch sizes on SM100.
# Fall back to fused_experts in small batch sizes.
if self.is_fp8_w8a8_sm100 and topk_ids.shape[0] <= 8:
from vllm.model_executor.layers.fused_moe import fused_experts
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True,
per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts,
expert_map=None if self.disable_expert_map else expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale)
if self.fused_experts is None:
# If no modular kernel is provided, use cutlass_moe_fp8
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp8)
return cutlass_moe_fp8(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
per_act_token=per_act_token,
activation=activation,
global_num_experts=global_num_experts,
expert_map=None if self.disable_expert_map else expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)
else:
return self.fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
activation=activation,
global_num_experts=global_num_experts,
expert_map=None if self.disable_expert_map else expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)
class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
def __init__(