mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-26 04:12:34 +08:00
[Quantization] Expand compressed-tensors MoE matching logic to support NFP4 + FP8 MoEs (#22674)
Signed-off-by: Dipika Sikka <dipikasikka1@gmail.com> Signed-off-by: Dipika <dipikasikka1@gmail.com>
This commit is contained in:
parent
142ac08030
commit
d272415e57
@ -425,6 +425,10 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
weight_quant: BaseModel,
|
||||
input_quant: BaseModel,
|
||||
format: Optional[str] = None) -> "CompressedTensorsScheme":
|
||||
|
||||
# use the per-layer format if defined, otherwise, use global format
|
||||
format = format if format is not None else self.quant_format
|
||||
|
||||
# Detect If Mixed Precision
|
||||
if self._is_fp4a16_nvfp4(weight_quant, input_quant):
|
||||
return CompressedTensorsW4A16Fp4()
|
||||
@ -437,14 +441,14 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
actorder=weight_quant.actorder)
|
||||
|
||||
if self._is_wNa16_group_channel(weight_quant, input_quant):
|
||||
if (self.quant_format == CompressionFormat.marlin_24.value
|
||||
if (format == CompressionFormat.marlin_24.value
|
||||
and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
|
||||
assert weight_quant.symmetric
|
||||
return CompressedTensorsW4A16Sparse24(
|
||||
strategy=weight_quant.strategy,
|
||||
num_bits=weight_quant.num_bits,
|
||||
group_size=weight_quant.group_size)
|
||||
if (self.quant_format == CompressionFormat.pack_quantized.value
|
||||
if (format == CompressionFormat.pack_quantized.value
|
||||
and weight_quant.num_bits in WNA16_SUPPORTED_BITS):
|
||||
return CompressedTensorsWNA16(
|
||||
num_bits=weight_quant.num_bits,
|
||||
@ -453,10 +457,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
group_size=weight_quant.group_size,
|
||||
actorder=weight_quant.actorder)
|
||||
|
||||
act_quant_format = is_activation_quantization_format(
|
||||
format
|
||||
) if format is not None else is_activation_quantization_format(
|
||||
self.quant_format)
|
||||
act_quant_format = is_activation_quantization_format(format)
|
||||
if act_quant_format:
|
||||
if self._is_fp4a4_nvfp4(weight_quant, input_quant):
|
||||
if cutlass_fp4_supported(
|
||||
|
||||
@ -22,6 +22,8 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
is_valid_flashinfer_cutlass_fused_moe)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
|
||||
WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||
find_matched_target)
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
|
||||
build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1,
|
||||
@ -65,12 +67,40 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
||||
@staticmethod
|
||||
def get_moe_method(
|
||||
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
|
||||
layer: torch.nn.Module,
|
||||
layer: torch.nn.Module
|
||||
) -> "CompressedTensorsMoEMethod":
|
||||
# TODO: @dsikka: refactor this to use schemes as other kernels
|
||||
# are supported + check if the layer is being ignored.
|
||||
weight_quant = quant_config.target_scheme_map["Linear"].get("weights")
|
||||
input_quant = quant_config.target_scheme_map["Linear"].get(
|
||||
# Check if a using "Linear" to select scheems
|
||||
if "Linear" in quant_config.target_scheme_map:
|
||||
matched_target = "Linear"
|
||||
else:
|
||||
# May have instead defined the linear layers in the fused model
|
||||
|
||||
fused_layers = [
|
||||
"re:.*down_proj.*", "re:.*gate_proj.*", "re:.*up_proj.*"
|
||||
]
|
||||
current_scheme = None
|
||||
for fused_layer in fused_layers:
|
||||
# Check if one of the fused layers are defined in quant_config
|
||||
matched_target = find_matched_target(
|
||||
layer_name=fused_layer,
|
||||
module=layer,
|
||||
targets=quant_config.target_scheme_map.keys(),
|
||||
fused_mapping=quant_config.packed_modules_mapping)
|
||||
|
||||
# Only valid if down_proj, gate_proj, and up_proj
|
||||
# are mapped to the same quant scheme in the quant_config
|
||||
if current_scheme is None:
|
||||
current_scheme = quant_config.target_scheme_map.get(
|
||||
matched_target)
|
||||
else:
|
||||
assert current_scheme == quant_config.target_scheme_map.get(
|
||||
matched_target)
|
||||
|
||||
weight_quant = quant_config.target_scheme_map[matched_target].get(
|
||||
"weights")
|
||||
input_quant = quant_config.target_scheme_map[matched_target].get(
|
||||
"input_activations")
|
||||
|
||||
if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user