diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index ae16a20cfaab9..4a3fc2a1a6b9c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -2,7 +2,7 @@ import enum from enum import Enum -from typing import Callable, List, Optional +from typing import Callable, Optional import torch from compressed_tensors import CompressionFormat @@ -14,9 +14,12 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) -from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - WNA16_SUPPORTED_BITS) +from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa + WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP) from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + check_moe_marlin_supports_layer, marlin_make_workspace_new, + marlin_moe_permute_scales) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) from vllm.model_executor.utils import set_weight_attrs @@ -54,18 +57,19 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): "input_activations") if quant_config._is_wNa16_group_channel(weight_quant, input_quant): - # Prefer to use the non-marlin kernel when: - # 1. Many experts (MarlinMoE gives poor performance when >= 16) - # 2. Non-FP16 dtype (MarlinMoE only supports FP16) - # 3. Actorder is not group/dynamic (g_idx is unsupported) - # 4. Scaled are grouped (channelwise is unsupported) - if ((layer.local_num_experts >= 16 - or layer.params_dtype != torch.float16) and - weight_quant.actorder not in (ActivationOrdering.GROUP, - ActivationOrdering.DYNAMIC) - and weight_quant.strategy in QuantizationStrategy.GROUP): + # Prefer to use the MarlinMoE kernel when it is supported. + if not check_moe_marlin_supports_layer(layer, + weight_quant.group_size): + if (weight_quant.strategy in QuantizationStrategy.GROUP and + weight_quant.actorder in (ActivationOrdering.GROUP, + ActivationOrdering.DYNAMIC)): + raise ValueError( + "WNA16MoE is not supported with actorder=group/dynamic." + ) + logger.info_once("Using CompressedTensorsWNA16MoEMethod") return CompressedTensorsWNA16MoEMethod(quant_config) else: + logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod") return CompressedTensorsWNA16MarlinMoEMethod(quant_config) elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) and layer.activation == "silu"): @@ -705,15 +709,12 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): f"{CompressionFormat.pack_quantized.value} ", "is supported for the following bits: ", f"{WNA16_SUPPORTED_BITS}") + self.quant_type = WNA16_SUPPORTED_TYPES_MAP[self.num_bits] def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): - assert params_dtype == torch.float16, ( - "float16 is required for MoE compressed models. Set dtype=torch.float16" # noqa: E501 - ) - intermediate_size_full = extra_weight_attrs.pop( "intermediate_size_full") @@ -837,50 +838,6 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): layer.marlin_state = GPTQMarlinState.REPACK def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - - def replace_tensor(name, new_t): - # It is important to use resize_() here since it ensures - # the same buffer is reused - getattr(layer, name).resize_(new_t.shape) - getattr(layer, name).copy_(new_t) - del new_t - - def get_scale_perms(num_bits: int): - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] - for i in range(4): - scale_perm_single.extend( - [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return scale_perm, scale_perm_single - - def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, - group_size: int, num_bits: int): - scale_perm, scale_perm_single = get_scale_perms(num_bits) - if group_size < size_k and group_size != -1: - s = s.reshape((-1, len(scale_perm)))[:, scale_perm] - else: - s = s.reshape((-1, len(scale_perm_single)))[:, - scale_perm_single] - s = s.reshape((-1, size_n)).contiguous() - return s - - def marlin_moe_permute_scales(s: torch.Tensor, size_k: int, - size_n: int, group_size: int, - num_bits: int): - num_experts = s.shape[0] - output = torch.empty((num_experts, s.shape[1], s.shape[2]), - device=s.device, - dtype=s.dtype) - for e in range(num_experts): - output[e] = marlin_permute_scales(s[e], size_k, size_n, - group_size, num_bits) - return output - - size_k2 = layer.w2_weight_packed.shape[2] - size_k13 = layer.w13_weight_packed.shape[2] - num_experts = layer.w13_weight_g_idx.shape[0] device = layer.w13_weight_g_idx.device @@ -938,7 +895,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): layer.w13_weight_packed.shape[2], self.num_bits, ) - replace_tensor("w13_weight_packed", marlin_w13_qweight) + replace_parameter(layer, "w13_weight_packed", marlin_w13_qweight) marlin_w2_qweight = ops.gptq_marlin_moe_repack( layer.w2_weight_packed, layer.w2_g_idx_sort_indices, @@ -946,25 +903,25 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): layer.w2_weight_packed.shape[2], self.num_bits, ) - replace_tensor("w2_weight_packed", marlin_w2_qweight) + replace_parameter(layer, "w2_weight_packed", marlin_w2_qweight) # Repack scales marlin_w13_scales = marlin_moe_permute_scales( - layer.w13_weight_scale, - size_k13, - layer.w13_weight_scale.shape[2], - self.group_size, - self.num_bits, + s=layer.w13_weight_scale, + size_k=layer.w13_weight_packed.shape[2], + size_n=layer.w13_weight_scale.shape[2], + group_size=self.group_size, ) - replace_tensor("w13_weight_scale", marlin_w13_scales) + replace_parameter(layer, "w13_weight_scale", marlin_w13_scales) marlin_w2_scales = marlin_moe_permute_scales( - layer.w2_weight_scale, - layer.w2_weight_scale.shape[1] * + s=layer.w2_weight_scale, + size_k=layer.w2_weight_scale.shape[1] * (self.group_size if self.group_size != -1 else self.packed_factor), - size_k2, - self.group_size, - self.num_bits, + size_n=layer.w2_weight_scale.shape[2], + group_size=self.group_size, ) - replace_tensor("w2_weight_scale", marlin_w2_scales) + replace_parameter(layer, "w2_weight_scale", marlin_w2_scales) + + layer.workspace = marlin_make_workspace_new(device, 4) def apply( self, @@ -985,10 +942,6 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): activation: str = "silu", ) -> torch.Tensor: assert activation == "silu", "Only SiLU activation is supported." - if expert_map is not None: - raise NotImplementedError( - "Expert Parallelism is not supported for " - "fused Marlin MoE method.") if apply_router_weight_on_input: raise NotImplementedError( "Apply router weight on input is not supported for " @@ -1015,11 +968,14 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): router_logits, topk_weights, topk_ids, + quant_type_id=self.quant_type.id, + global_num_experts=global_num_experts, + expert_map=expert_map, g_idx1=layer.w13_weight_g_idx, g_idx2=layer.w2_weight_g_idx, sort_indices1=layer.w13_g_idx_sort_indices, sort_indices2=layer.w2_g_idx_sort_indices, - num_bits=self.num_bits, + workspace=layer.workspace, is_k_full=self.is_k_full) @@ -1203,7 +1159,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): activation: str = "silu", ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe import fused_experts - assert activation == "silu", "Only SiLU activation is supported." + topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, @@ -1223,6 +1179,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, + activation=activation, use_int4_w4a16=self.num_bits == 4, use_int8_w8a16=self.num_bits == 8, global_num_experts=global_num_experts,