[Quantization] Expand compressed-tensors MoE matching logic to support NFP4 + FP8 MoEs (#22674)

Signed-off-by: Dipika Sikka <dipikasikka1@gmail.com>
Signed-off-by: Dipika <dipikasikka1@gmail.com>
This commit is contained in:
Dipika Sikka 2025-08-27 01:00:21 -04:00 committed by GitHub
parent 142ac08030
commit d272415e57
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 40 additions and 9 deletions

View File

@ -425,6 +425,10 @@ class CompressedTensorsConfig(QuantizationConfig):
weight_quant: BaseModel,
input_quant: BaseModel,
format: Optional[str] = None) -> "CompressedTensorsScheme":
# use the per-layer format if defined, otherwise, use global format
format = format if format is not None else self.quant_format
# Detect If Mixed Precision
if self._is_fp4a16_nvfp4(weight_quant, input_quant):
return CompressedTensorsW4A16Fp4()
@ -437,14 +441,14 @@ class CompressedTensorsConfig(QuantizationConfig):
actorder=weight_quant.actorder)
if self._is_wNa16_group_channel(weight_quant, input_quant):
if (self.quant_format == CompressionFormat.marlin_24.value
if (format == CompressionFormat.marlin_24.value
and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
assert weight_quant.symmetric
return CompressedTensorsW4A16Sparse24(
strategy=weight_quant.strategy,
num_bits=weight_quant.num_bits,
group_size=weight_quant.group_size)
if (self.quant_format == CompressionFormat.pack_quantized.value
if (format == CompressionFormat.pack_quantized.value
and weight_quant.num_bits in WNA16_SUPPORTED_BITS):
return CompressedTensorsWNA16(
num_bits=weight_quant.num_bits,
@ -453,10 +457,7 @@ class CompressedTensorsConfig(QuantizationConfig):
group_size=weight_quant.group_size,
actorder=weight_quant.actorder)
act_quant_format = is_activation_quantization_format(
format
) if format is not None else is_activation_quantization_format(
self.quant_format)
act_quant_format = is_activation_quantization_format(format)
if act_quant_format:
if self._is_fp4a4_nvfp4(weight_quant, input_quant):
if cutlass_fp4_supported(

View File

@ -22,6 +22,8 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
is_valid_flashinfer_cutlass_fused_moe)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
find_matched_target)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1,
@ -65,12 +67,40 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
@staticmethod
def get_moe_method(
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
layer: torch.nn.Module,
layer: torch.nn.Module
) -> "CompressedTensorsMoEMethod":
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
weight_quant = quant_config.target_scheme_map["Linear"].get("weights")
input_quant = quant_config.target_scheme_map["Linear"].get(
# Check if using "Linear" to select schemes
if "Linear" in quant_config.target_scheme_map:
matched_target = "Linear"
else:
# May have instead defined the linear layers in the fused model
fused_layers = [
"re:.*down_proj.*", "re:.*gate_proj.*", "re:.*up_proj.*"
]
current_scheme = None
for fused_layer in fused_layers:
# Check if one of the fused layers is defined in quant_config
matched_target = find_matched_target(
layer_name=fused_layer,
module=layer,
targets=quant_config.target_scheme_map.keys(),
fused_mapping=quant_config.packed_modules_mapping)
# Only valid if down_proj, gate_proj, and up_proj
# are mapped to the same quant scheme in the quant_config
if current_scheme is None:
current_scheme = quant_config.target_scheme_map.get(
matched_target)
else:
assert current_scheme == quant_config.target_scheme_map.get(
matched_target)
weight_quant = quant_config.target_scheme_map[matched_target].get(
"weights")
input_quant = quant_config.target_scheme_map[matched_target].get(
"input_activations")
if quant_config._is_wNa16_group_channel(weight_quant, input_quant):