Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-06-11 06:09:09 +08:00)
[Refactor] Merge Compressed Tensor FP8 CompressedTensorsW8A8Fp8MoEMethod and CompressedTensorsW8A8Fp8MoECutlassMethod (#21775)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent 947e982ede
commit 48b763d6b5
@@ -45,7 +45,6 @@ class GPTQMarlinState(Enum):
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod",
|
"CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod",
|
||||||
"CompressedTensorsW8A8Fp8MoECutlassMethod",
|
|
||||||
"CompressedTensorsW8A8Int8MoEMethod",
|
"CompressedTensorsW8A8Int8MoEMethod",
|
||||||
"CompressedTensorsWNA16MarlinMoEMethod", "CompressedTensorsWNA16MoEMethod",
|
"CompressedTensorsWNA16MarlinMoEMethod", "CompressedTensorsWNA16MoEMethod",
|
||||||
"CompressedTensorsW4A4MoeMethod"
|
"CompressedTensorsW4A4MoeMethod"
|
||||||
@@ -84,9 +83,8 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
|||||||
elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
|
elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
|
||||||
return CompressedTensorsW4A4MoeMethod()
|
return CompressedTensorsW4A4MoeMethod()
|
||||||
elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
|
elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
|
||||||
or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)):
|
or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)
|
||||||
return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
|
or quant_config._is_fp8_w8a8(weight_quant, input_quant)):
|
||||||
elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
|
|
||||||
return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
|
return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
|
||||||
elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
|
elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
|
||||||
return CompressedTensorsW8A8Int8MoEMethod(quant_config)
|
return CompressedTensorsW8A8Int8MoEMethod(quant_config)
|
||||||
@@ -378,6 +376,14 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
|
|
||||||
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
|
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
|
||||||
|
|
||||||
|
# cutlass path
|
||||||
|
self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100(
|
||||||
|
self.weight_quant, self.input_quant)
|
||||||
|
self.use_cutlass = (quant_config._is_fp8_w8a8_sm90(
|
||||||
|
self.weight_quant, self.input_quant) or self.is_fp8_w8a8_sm100)
|
||||||
|
self.fused_experts = None # type: ignore[assignment]
|
||||||
|
self.disable_expert_map = False
|
||||||
|
|
||||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||||
hidden_size: int, intermediate_size_per_partition: int,
|
hidden_size: int, intermediate_size_per_partition: int,
|
||||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||||
@@ -558,6 +564,34 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||||
moe: FusedMoEConfig,
|
moe: FusedMoEConfig,
|
||||||
) -> FusedMoEPermuteExpertsUnpermute:
|
) -> FusedMoEPermuteExpertsUnpermute:
|
||||||
|
# cutlass path
|
||||||
|
if self.use_cutlass:
|
||||||
|
from vllm.model_executor.layers.fused_moe import CutlassExpertsFp8
|
||||||
|
|
||||||
|
use_batched_format = (prepare_finalize.activation_format ==
|
||||||
|
FusedMoEActivationFormat.BatchedExperts)
|
||||||
|
|
||||||
|
num_dispatchers = prepare_finalize.num_dispatchers()
|
||||||
|
num_experts = (moe.num_local_experts
|
||||||
|
if use_batched_format else moe.num_experts)
|
||||||
|
|
||||||
|
logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
|
||||||
|
|
||||||
|
experts = CutlassExpertsFp8(
|
||||||
|
num_experts,
|
||||||
|
moe.in_dtype,
|
||||||
|
self.input_quant.strategy == QuantizationStrategy.TOKEN,
|
||||||
|
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
|
||||||
|
num_dispatchers=num_dispatchers,
|
||||||
|
use_batched_format=use_batched_format,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.disable_expert_map = (num_dispatchers > 1
|
||||||
|
or not experts.supports_expert_map())
|
||||||
|
|
||||||
|
return experts
|
||||||
|
|
||||||
|
# triton path
|
||||||
from vllm.model_executor.layers.fused_moe import TritonExperts
|
from vllm.model_executor.layers.fused_moe import TritonExperts
|
||||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||||
BatchedTritonExperts)
|
BatchedTritonExperts)
|
||||||
@@ -629,6 +663,68 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
indices_type=self.topk_indices_dtype,
|
indices_type=self.topk_indices_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# cutlass path
|
||||||
|
if self.use_cutlass:
|
||||||
|
per_act_token = (
|
||||||
|
self.input_quant.strategy == QuantizationStrategy.TOKEN)
|
||||||
|
per_channel_quant = (
|
||||||
|
self.weight_quant.strategy == QuantizationStrategy.CHANNEL)
|
||||||
|
|
||||||
|
# small-batch fallback on SM100
|
||||||
|
if self.is_fp8_w8a8_sm100 and topk_ids.shape[0] <= 8:
|
||||||
|
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||||
|
return fused_experts(
|
||||||
|
hidden_states=x,
|
||||||
|
w1=layer.w13_weight,
|
||||||
|
w2=layer.w2_weight,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
inplace=True,
|
||||||
|
activation=activation,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
|
use_fp8_w8a8=True,
|
||||||
|
per_channel_quant=per_channel_quant,
|
||||||
|
global_num_experts=global_num_experts,
|
||||||
|
expert_map=None if self.disable_expert_map else expert_map,
|
||||||
|
w1_scale=layer.w13_weight_scale,
|
||||||
|
w2_scale=layer.w2_weight_scale,
|
||||||
|
a1_scale=layer.w13_input_scale,
|
||||||
|
a2_scale=layer.w2_input_scale)
|
||||||
|
|
||||||
|
if self.fused_experts is None:
|
||||||
|
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||||
|
cutlass_moe_fp8)
|
||||||
|
return cutlass_moe_fp8(
|
||||||
|
x,
|
||||||
|
layer.w13_weight,
|
||||||
|
layer.w2_weight,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
per_act_token=per_act_token,
|
||||||
|
activation=activation,
|
||||||
|
global_num_experts=global_num_experts,
|
||||||
|
expert_map=None if self.disable_expert_map else expert_map,
|
||||||
|
w1_scale=layer.w13_weight_scale,
|
||||||
|
w2_scale=layer.w2_weight_scale,
|
||||||
|
a1_scale=layer.w13_input_scale,
|
||||||
|
a2_scale=layer.w2_input_scale,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return self.fused_experts(
|
||||||
|
x,
|
||||||
|
layer.w13_weight,
|
||||||
|
layer.w2_weight,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
activation=activation,
|
||||||
|
global_num_experts=global_num_experts,
|
||||||
|
expert_map=None if self.disable_expert_map else expert_map,
|
||||||
|
w1_scale=layer.w13_weight_scale,
|
||||||
|
w2_scale=layer.w2_weight_scale,
|
||||||
|
a1_scale=layer.w13_input_scale,
|
||||||
|
a2_scale=layer.w2_input_scale,
|
||||||
|
)
|
||||||
|
|
||||||
if self.rocm_aiter_moe_enabled:
|
if self.rocm_aiter_moe_enabled:
|
||||||
return self.rocm_aiter_fused_experts_func(
|
return self.rocm_aiter_fused_experts_func(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
@@ -685,291 +781,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
a2_scale=layer.w2_input_scale)
|
a2_scale=layer.w2_input_scale)
|
||||||
|
|
||||||
|
|
||||||
class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
|
|
||||||
):
|
|
||||||
self.quant_config = quant_config
|
|
||||||
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
|
|
||||||
"weights")
|
|
||||||
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
|
|
||||||
"input_activations")
|
|
||||||
|
|
||||||
per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR
|
|
||||||
and self.input_quant.strategy
|
|
||||||
== QuantizationStrategy.TENSOR)
|
|
||||||
per_channel = (
|
|
||||||
self.weight_quant.strategy == QuantizationStrategy.CHANNEL
|
|
||||||
and self.input_quant.strategy == QuantizationStrategy.TOKEN)
|
|
||||||
if not (per_tensor or per_channel):
|
|
||||||
raise ValueError(
|
|
||||||
"For FP8 Fused MoE layers, we require per tensor "
|
|
||||||
"or channelwise, dynamic per token quantization. Found "
|
|
||||||
f"{self.weight_quant}, {self.input_quant}")
|
|
||||||
|
|
||||||
self.static_input_scales = not self.input_quant.dynamic
|
|
||||||
if self.static_input_scales and per_channel:
|
|
||||||
raise ValueError(
|
|
||||||
"For FP8 Fused MoE layer, we require either per tensor or "
|
|
||||||
"channelwise, dynamic per token quantization.")
|
|
||||||
|
|
||||||
self.topk_indices_dtype = None
|
|
||||||
self.fused_experts = None # type: ignore
|
|
||||||
self.disable_expert_map = False
|
|
||||||
self.is_fp8_w8a8_sm100 = self.quant_config._is_fp8_w8a8_sm100(
|
|
||||||
self.weight_quant, self.input_quant)
|
|
||||||
|
|
||||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
|
||||||
hidden_size: int, intermediate_size_per_partition: int,
|
|
||||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
|
||||||
|
|
||||||
params_dtype = torch.float8_e4m3fn
|
|
||||||
|
|
||||||
# WEIGHTS
|
|
||||||
w13_weight = torch.nn.Parameter(torch.empty(
|
|
||||||
num_experts,
|
|
||||||
2 * intermediate_size_per_partition,
|
|
||||||
hidden_size,
|
|
||||||
dtype=params_dtype),
|
|
||||||
requires_grad=False)
|
|
||||||
layer.register_parameter("w13_weight", w13_weight)
|
|
||||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
|
||||||
|
|
||||||
w2_weight = torch.nn.Parameter(torch.empty(
|
|
||||||
num_experts,
|
|
||||||
hidden_size,
|
|
||||||
intermediate_size_per_partition,
|
|
||||||
dtype=params_dtype),
|
|
||||||
requires_grad=False)
|
|
||||||
layer.register_parameter("w2_weight", w2_weight)
|
|
||||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
|
||||||
|
|
||||||
# WEIGHT_SCALES
|
|
||||||
if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
|
|
||||||
# Allocate 2 scales for w1 and w3 respectively.
|
|
||||||
# They are combined to a single scale after weight loading.
|
|
||||||
w13_weight_scale = torch.nn.Parameter(torch.ones(
|
|
||||||
num_experts, 2, dtype=torch.float32),
|
|
||||||
requires_grad=False)
|
|
||||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
|
||||||
w2_weight_scale = torch.nn.Parameter(torch.ones(
|
|
||||||
num_experts, dtype=torch.float32),
|
|
||||||
requires_grad=False)
|
|
||||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
|
||||||
# Add PER-TENSOR quantization for FusedMoE.weight_loader.
|
|
||||||
extra_weight_attrs.update(
|
|
||||||
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
|
|
||||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
|
||||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
|
||||||
|
|
||||||
elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
|
|
||||||
w13_weight_scale = torch.nn.Parameter(torch.ones(
|
|
||||||
num_experts,
|
|
||||||
2 * intermediate_size_per_partition,
|
|
||||||
1,
|
|
||||||
dtype=torch.float32),
|
|
||||||
requires_grad=False)
|
|
||||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
|
||||||
w2_weight_scale = torch.nn.Parameter(torch.ones(
|
|
||||||
num_experts, hidden_size, 1, dtype=torch.float32),
|
|
||||||
requires_grad=False)
|
|
||||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
|
||||||
# Add PER-CHANNEL quantization for FusedMoE.weight_loader.
|
|
||||||
extra_weight_attrs.update(
|
|
||||||
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
|
|
||||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
|
||||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
|
||||||
|
|
||||||
# INPUT_SCALES
|
|
||||||
if self.static_input_scales:
|
|
||||||
w13_input_scale = torch.nn.Parameter(torch.ones(
|
|
||||||
num_experts, dtype=torch.float32),
|
|
||||||
requires_grad=False)
|
|
||||||
layer.register_parameter("w13_input_scale", w13_input_scale)
|
|
||||||
set_weight_attrs(w13_input_scale, extra_weight_attrs)
|
|
||||||
|
|
||||||
w2_input_scale = torch.nn.Parameter(torch.ones(
|
|
||||||
num_experts, dtype=torch.float32),
|
|
||||||
requires_grad=False)
|
|
||||||
layer.register_parameter("w2_input_scale", w2_input_scale)
|
|
||||||
set_weight_attrs(w2_input_scale, extra_weight_attrs)
|
|
||||||
else:
|
|
||||||
layer.w13_input_scale = None
|
|
||||||
layer.w2_input_scale = None
|
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
|
||||||
# Fp8 moe kernels require a single activation scale.
|
|
||||||
# We take the max of all the scales in case they differ.
|
|
||||||
if self.static_input_scales:
|
|
||||||
assert self.input_quant.strategy == QuantizationStrategy.TENSOR
|
|
||||||
if (layer.w13_input_scale is None or layer.w2_input_scale is None):
|
|
||||||
raise ValueError(
|
|
||||||
"QuantConfig has static quantization, but found "
|
|
||||||
"activation scales are None.")
|
|
||||||
if (not all_close_1d(layer.w13_input_scale)
|
|
||||||
or not all_close_1d(layer.w2_input_scale)):
|
|
||||||
logger.warning_once(
|
|
||||||
"Found input_scales that are not equal for "
|
|
||||||
"fp8 MoE layer. Using the maximum across experts "
|
|
||||||
"for each layer.")
|
|
||||||
layer.w13_input_scale = torch.nn.Parameter(
|
|
||||||
layer.w13_input_scale.max(), requires_grad=False)
|
|
||||||
layer.w2_input_scale = torch.nn.Parameter(
|
|
||||||
layer.w2_input_scale.max(), requires_grad=False)
|
|
||||||
|
|
||||||
# For Per-TENSOR case, Fp8 moe kernel needs single weight scale
|
|
||||||
# for w13 per expert. Use max then dequant and requant each expert.
|
|
||||||
if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
|
|
||||||
assert layer.w13_weight_scale is not None
|
|
||||||
shard_size = layer.intermediate_size_per_partition
|
|
||||||
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
|
|
||||||
for expert_id in range(layer.local_num_experts):
|
|
||||||
start = 0
|
|
||||||
for shard_id in range(2):
|
|
||||||
dq_weight = per_tensor_dequantize(
|
|
||||||
layer.w13_weight[expert_id][start:start +
|
|
||||||
shard_size, :],
|
|
||||||
layer.w13_weight_scale[expert_id][shard_id])
|
|
||||||
layer.w13_weight[expert_id][
|
|
||||||
start:start + shard_size, :], _ = ops.scaled_fp8_quant(
|
|
||||||
dq_weight, max_w13_scales[expert_id])
|
|
||||||
start += shard_size
|
|
||||||
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
|
|
||||||
requires_grad=False)
|
|
||||||
|
|
||||||
def select_gemm_impl(
|
|
||||||
self,
|
|
||||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
|
||||||
moe: FusedMoEConfig,
|
|
||||||
) -> FusedMoEPermuteExpertsUnpermute:
|
|
||||||
from vllm.model_executor.layers.fused_moe import CutlassExpertsFp8
|
|
||||||
|
|
||||||
use_batched_format = (prepare_finalize.activation_format ==
|
|
||||||
FusedMoEActivationFormat.BatchedExperts)
|
|
||||||
|
|
||||||
num_dispatchers = prepare_finalize.num_dispatchers()
|
|
||||||
|
|
||||||
num_experts = (moe.num_local_experts
|
|
||||||
if use_batched_format else moe.num_experts)
|
|
||||||
|
|
||||||
logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
|
|
||||||
|
|
||||||
experts = CutlassExpertsFp8(
|
|
||||||
num_experts,
|
|
||||||
moe.in_dtype,
|
|
||||||
self.input_quant.strategy == QuantizationStrategy.TOKEN,
|
|
||||||
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
|
|
||||||
num_dispatchers=num_dispatchers,
|
|
||||||
use_batched_format=use_batched_format,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.disable_expert_map = (num_dispatchers > 1
|
|
||||||
or not experts.supports_expert_map())
|
|
||||||
|
|
||||||
return experts
|
|
||||||
|
|
||||||
def apply(
|
|
||||||
self,
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
x: torch.Tensor,
|
|
||||||
router_logits: torch.Tensor,
|
|
||||||
top_k: int,
|
|
||||||
renormalize: bool,
|
|
||||||
use_grouped_topk: bool = False,
|
|
||||||
topk_group: Optional[int] = None,
|
|
||||||
num_expert_group: Optional[int] = None,
|
|
||||||
global_num_experts: int = -1,
|
|
||||||
expert_map: Optional[torch.Tensor] = None,
|
|
||||||
custom_routing_function: Optional[Callable] = None,
|
|
||||||
scoring_func: str = "softmax",
|
|
||||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
activation: str = "silu",
|
|
||||||
enable_eplb: bool = False,
|
|
||||||
expert_load_view: Optional[torch.Tensor] = None,
|
|
||||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
|
||||||
logical_replica_count: Optional[torch.Tensor] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
if enable_eplb:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"EPLB not supported for "
|
|
||||||
"`CompressedTensorsW8A8Fp8MoECutlassMethod` yet.")
|
|
||||||
|
|
||||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
|
||||||
hidden_states=x,
|
|
||||||
router_logits=router_logits,
|
|
||||||
use_grouped_topk=use_grouped_topk,
|
|
||||||
top_k=top_k,
|
|
||||||
renormalize=renormalize,
|
|
||||||
topk_group=topk_group,
|
|
||||||
num_expert_group=num_expert_group,
|
|
||||||
custom_routing_function=custom_routing_function,
|
|
||||||
scoring_func=scoring_func,
|
|
||||||
e_score_correction_bias=e_score_correction_bias)
|
|
||||||
|
|
||||||
per_act_token = (
|
|
||||||
self.input_quant.strategy == QuantizationStrategy.TOKEN)
|
|
||||||
per_channel_quant = (
|
|
||||||
self.weight_quant.strategy == QuantizationStrategy.CHANNEL)
|
|
||||||
# Triton fused_experts is faster in small batch sizes on SM100.
|
|
||||||
# Fall back to fused_experts in small batch sizes.
|
|
||||||
if self.is_fp8_w8a8_sm100 and topk_ids.shape[0] <= 8:
|
|
||||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
|
||||||
return fused_experts(
|
|
||||||
x,
|
|
||||||
layer.w13_weight,
|
|
||||||
layer.w2_weight,
|
|
||||||
topk_weights,
|
|
||||||
topk_ids,
|
|
||||||
inplace=True,
|
|
||||||
activation=activation,
|
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
|
||||||
use_fp8_w8a8=True,
|
|
||||||
per_channel_quant=per_channel_quant,
|
|
||||||
global_num_experts=global_num_experts,
|
|
||||||
expert_map=None if self.disable_expert_map else expert_map,
|
|
||||||
w1_scale=layer.w13_weight_scale,
|
|
||||||
w2_scale=layer.w2_weight_scale,
|
|
||||||
a1_scale=layer.w13_input_scale,
|
|
||||||
a2_scale=layer.w2_input_scale)
|
|
||||||
if self.fused_experts is None:
|
|
||||||
# If no modular kernel is provided, use cutlass_moe_fp8
|
|
||||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
|
||||||
cutlass_moe_fp8)
|
|
||||||
return cutlass_moe_fp8(
|
|
||||||
x,
|
|
||||||
layer.w13_weight,
|
|
||||||
layer.w2_weight,
|
|
||||||
topk_weights,
|
|
||||||
topk_ids,
|
|
||||||
per_act_token=per_act_token,
|
|
||||||
activation=activation,
|
|
||||||
global_num_experts=global_num_experts,
|
|
||||||
expert_map=None if self.disable_expert_map else expert_map,
|
|
||||||
w1_scale=layer.w13_weight_scale,
|
|
||||||
w2_scale=layer.w2_weight_scale,
|
|
||||||
a1_scale=layer.w13_input_scale,
|
|
||||||
a2_scale=layer.w2_input_scale,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return self.fused_experts(
|
|
||||||
x,
|
|
||||||
layer.w13_weight,
|
|
||||||
layer.w2_weight,
|
|
||||||
topk_weights,
|
|
||||||
topk_ids,
|
|
||||||
activation=activation,
|
|
||||||
global_num_experts=global_num_experts,
|
|
||||||
expert_map=None if self.disable_expert_map else expert_map,
|
|
||||||
w1_scale=layer.w13_weight_scale,
|
|
||||||
w2_scale=layer.w2_weight_scale,
|
|
||||||
a1_scale=layer.w13_input_scale,
|
|
||||||
a2_scale=layer.w2_input_scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user