[refactor] CTMoEMethods to use QuantizationArgs (#28871)

Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Author: HDCharles, 2025-12-03 06:00:56 -05:00 (committed by GitHub)
Commit: b294e28db2 (parent 787b84a9fc)
2 changed files with 86 additions and 75 deletions

File 1 of 2: vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py (CompressedTensorsConfig)

@@ -767,8 +767,10 @@ class CompressedTensorsConfig(QuantizationConfig):
                 targets=self.target_scheme_map.keys(),
                 fused_mapping=self.packed_modules_mapping,
             )
-            return self.target_scheme_map[matched_target]
+            scheme_dict = self.target_scheme_map[matched_target]
+            if scheme_dict.get("format") is None:
+                scheme_dict["format"] = self.quant_format
+            return scheme_dict
         return None
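
Reviewer note: the net effect of this hunk is that the returned scheme dict always carries a "format" entry, falling back to the config-level quant_format when the matched target does not pin one. A minimal stand-alone sketch of that fallback (the literal format string is an assumption used only for illustration):

# Hypothetical illustration of the "format" fallback introduced above.
quant_format = "pack-quantized"  # assumed config-level format value
scheme_dict = {"weights": object(), "input_activations": None, "format": None}

if scheme_dict.get("format") is None:
    # The matched target did not specify a format; inherit the global one.
    scheme_dict["format"] = quant_format

assert scheme_dict["format"] == "pack-quantized"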

File 2 of 2: the compressed-tensors MoE methods module (CompressedTensorsMoEMethod and its subclasses)

@@ -7,7 +7,11 @@ from enum import Enum
 import torch
 from compressed_tensors import CompressionFormat
-from compressed_tensors.quantization import ActivationOrdering, QuantizationStrategy
+from compressed_tensors.quantization import (
+    ActivationOrdering,
+    QuantizationArgs,
+    QuantizationStrategy,
+)
 from torch.nn.parameter import Parameter
 import vllm.envs as envs
@@ -142,10 +146,26 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
         # are supported + check if the layer is being ignored.
         weight_quant = scheme_dict.get("weights")
         input_quant = scheme_dict.get("input_activations")
+        format = scheme_dict.get("format")
         if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
             # group_size=None means channelwise
             group_size = weight_quant.group_size or -1
+            valid_format_and_bits = (
+                weight_quant.num_bits in WNA16_SUPPORTED_BITS
+                and format == CompressionFormat.pack_quantized.value
+            )
+            if not valid_format_and_bits:
+                raise ValueError(
+                    "For Fused MoE layers, only format "
+                    f"{CompressionFormat.pack_quantized.value} with bits in "
+                    f"{WNA16_SUPPORTED_BITS} is supported, but got format "
+                    f"{format} and bits {weight_quant.num_bits}."
+                )
             # Prefer to use the MarlinMoE kernel when it is supported.
             if (
                 not check_moe_marlin_supports_layer(layer, group_size)
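
Reviewer note: the new valid_format_and_bits guard depends only on the bit-width and the scheme format, so it can be exercised in isolation. A rough sketch with assumed values (the supported bits and the format string are placeholders, not imported from vLLM):

# Hypothetical reproduction of the guard; constants are assumptions.
WNA16_SUPPORTED_BITS = (4, 8)      # assumed supported bit-widths
PACK_QUANTIZED = "pack-quantized"  # assumed CompressionFormat.pack_quantized.value

def is_valid_wna16_moe(num_bits: int, fmt: str | None) -> bool:
    # Mirrors the valid_format_and_bits expression in the hunk above.
    return num_bits in WNA16_SUPPORTED_BITS and fmt == PACK_QUANTIZED

assert is_valid_wna16_moe(4, "pack-quantized")
assert not is_valid_wna16_moe(3, "pack-quantized")  # unsupported bit-width
assert not is_valid_wna16_moe(4, None)              # format missing from the scheme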
@@ -161,12 +181,12 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
                 )
                 logger.info_once("Using CompressedTensorsWNA16MoEMethod")
                 return CompressedTensorsWNA16MoEMethod(
-                    quant_config, layer.moe_config, layer_name
+                    weight_quant, input_quant, layer.moe_config
                 )
             else:
                 logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
                 return CompressedTensorsWNA16MarlinMoEMethod(
-                    quant_config, layer.moe_config, layer_name
+                    weight_quant, input_quant, layer.moe_config
                 )
         elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
             return CompressedTensorsW4A4Nvfp4MoEMethod(layer.moe_config, layer_name)
@@ -176,15 +196,15 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
             or quant_config._is_fp8_w8a8(weight_quant, input_quant)
         ):
             return CompressedTensorsW8A8Fp8MoEMethod(
-                quant_config, layer.moe_config, layer_name
+                weight_quant, input_quant, layer.moe_config
             )
         elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
             return CompressedTensorsW8A8Int8MoEMethod(
-                quant_config, layer.moe_config, layer_name
+                weight_quant, input_quant, layer.moe_config
             )
         elif quant_config._is_dynamic_token_w4a8_int(weight_quant, input_quant):
             return CompressedTensorsW4A8Int8MoEMethod(
-                quant_config, layer.moe_config, layer_name
+                weight_quant, input_quant, layer.moe_config
             )
         else:
             raise RuntimeError(
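
Reviewer note: after this change every MoE method is constructed from the matched scheme's QuantizationArgs plus the FusedMoEConfig, rather than from the whole CompressedTensorsConfig. A sketch of the new call shape, assuming the compressed-tensors QuantizationArgs fields used elsewhere in this diff (num_bits, type, strategy); the instantiation itself is left commented out because it needs a real layer:

from compressed_tensors.quantization import QuantizationArgs

# Assumed construction of per-layer args for an fp8 W8A8 scheme (illustrative only).
weight_quant = QuantizationArgs(num_bits=8, type="float", strategy="tensor")
input_quant = QuantizationArgs(num_bits=8, type="float", strategy="tensor")
scheme_dict = {"weights": weight_quant, "input_activations": input_quant}

# method = CompressedTensorsW8A8Fp8MoEMethod(
#     scheme_dict["weights"], scheme_dict["input_activations"], layer.moe_config
# )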
@@ -650,17 +670,19 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
 class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self,
-        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
+        weight_quant: QuantizationArgs,
+        input_quant: QuantizationArgs,
         moe: FusedMoEConfig,
         layer_name: str | None = None,
     ):
-        super().__init__(moe)
-        self.quant_config = quant_config
-        self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
-        self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
-            "input_activations"
-        )
+        from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
+            CompressedTensorsConfig,
+        )
+        super().__init__(moe)
+        self.weight_quant = weight_quant
+        self.input_quant = input_quant
         per_tensor = (
             self.weight_quant.strategy == QuantizationStrategy.TENSOR
             and self.input_quant.strategy == QuantizationStrategy.TENSOR
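
Reviewer note: with the args passed in directly, the per_tensor check above reduces to comparing the two strategies. A tiny sketch using plain strings in place of the QuantizationStrategy members (assumed values):

def is_per_tensor(weight_strategy: str, activation_strategy: str) -> bool:
    # Mirrors the per_tensor expression above, with "tensor" standing in
    # for QuantizationStrategy.TENSOR.
    return weight_strategy == "tensor" and activation_strategy == "tensor"

assert is_per_tensor("tensor", "tensor")
assert not is_per_tensor("channel", "tensor")  # mixed strategies are not per-tensor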
@@ -698,11 +720,13 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
         # cutlass path
-        self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100(
+        self.is_fp8_w8a8_sm100 = CompressedTensorsConfig._is_fp8_w8a8_sm100(
             self.weight_quant, self.input_quant
         )
         self.use_cutlass = not self.block_quant and (
-            quant_config._is_fp8_w8a8_sm90(self.weight_quant, self.input_quant)
+            CompressedTensorsConfig._is_fp8_w8a8_sm90(
+                self.weight_quant, self.input_quant
+            )
             or self.is_fp8_w8a8_sm100
         )
         self.disable_expert_map = False
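
Reviewer note: because the method no longer holds a quant_config instance, the sm90/sm100 helpers are now called on the CompressedTensorsConfig class, which is imported inside __init__ (see the previous hunk) to avoid a module-level import cycle between the config and MoE-method modules. A generic, self-contained sketch of that deferred-import pattern (names here are stubs, not vLLM's):

class FakeConfig:
    # Stand-in for a classmethod-style capability check on the config class.
    @classmethod
    def _is_fp8_w8a8_sm90(cls, weight_quant, input_quant) -> bool:
        return weight_quant == "fp8" and input_quant == "fp8"

def build_method(weight_quant, input_quant):
    # The real code defers the import to call time:
    #   from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import CompressedTensorsConfig
    # Here the local stub plays that role.
    config_cls = FakeConfig
    return config_cls._is_fp8_w8a8_sm90(weight_quant, input_quant)

print(build_method("fp8", "fp8"))  # True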
@@ -1261,16 +1285,14 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
 class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self,
-        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
+        weight_quant: QuantizationArgs,
+        input_quant: QuantizationArgs,
         moe: FusedMoEConfig,
         layer_name: str | None = None,
     ):
         super().__init__(moe)
-        self.quant_config = quant_config
-        self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
-        self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
-            "input_activations"
-        )
+        self.weight_quant = weight_quant
+        self.input_quant = input_quant
         per_channel = (
             self.weight_quant.strategy == QuantizationStrategy.CHANNEL
@@ -1414,36 +1436,27 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
 class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self,
-        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
+        weight_quant: QuantizationArgs,
+        input_quant: QuantizationArgs | None,
         moe: FusedMoEConfig,
         layer_name: str | None = None,
     ):
         super().__init__(moe)
-        self.quant_config = quant_config
-        # TODO: @dsikka: refactor this to use schemes as other kernels
-        # are supported + check if the layer is being ignored.
-        config = self.quant_config.target_scheme_map["Linear"].get("weights")
-        self.num_bits = config.num_bits
-        self.packed_factor = 32 // config.num_bits
-        self.strategy = config.strategy
-        self.group_size = config.group_size
-        self.actorder = config.actorder
-        self.layer_name = layer_name
-        self.marlin_input_dtype = get_marlin_input_dtype(layer_name)
-        assert config.symmetric, "Only symmetric quantization is supported for MoE"
-        if not (
-            self.quant_config.quant_format == CompressionFormat.pack_quantized.value
-            and self.num_bits in WNA16_SUPPORTED_BITS
-        ):
-            raise ValueError(
-                "For Fused MoE layers, only ",
-                f"{CompressionFormat.pack_quantized.value} ",
-                "is supported for the following bits: ",
-                f"{WNA16_SUPPORTED_BITS}",
-            )
+        self.weight_quant = weight_quant
+        self.input_quant = input_quant
+        assert weight_quant.symmetric, (
+            "Only symmetric quantization is supported for MoE"
+        )
+        # Extract properties from weight_quant
+        self.num_bits = weight_quant.num_bits
+        self.packed_factor = 32 // weight_quant.num_bits
+        self.strategy = weight_quant.strategy
+        self.group_size = weight_quant.group_size
+        self.actorder = weight_quant.actorder
         self.quant_type = WNA16_SUPPORTED_TYPES_MAP[self.num_bits]
         self.use_marlin = True
+        self.marlin_input_dtype = get_marlin_input_dtype(layer_name)

     def create_weights(
         self,
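
Reviewer note: both WNA16 methods now read num_bits, strategy, group_size and actorder straight off the weight QuantizationArgs, and packed_factor is simply how many quantized weights fit into one 32-bit word. A quick worked check of that arithmetic:

for num_bits in (4, 8):
    packed_factor = 32 // num_bits
    print(num_bits, packed_factor)  # 4 bits -> 8 weights per int32, 8 bits -> 4 per int32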
@@ -1812,35 +1825,26 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
 class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self,
-        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
+        weight_quant: QuantizationArgs,
+        input_quant: QuantizationArgs | None,
         moe: FusedMoEConfig,
         layer_name: str | None = None,
     ):
         super().__init__(moe)
-        self.quant_config = quant_config
-        # TODO: @dsikka: refactor this to use schemes as other kernels
-        # are supported + check if the layer is being ignored.
-        config = self.quant_config.target_scheme_map["Linear"].get("weights")
-        self.num_bits = config.num_bits
-        self.packed_factor = 32 // config.num_bits
-        self.strategy = config.strategy
+        self.weight_quant = weight_quant
+        self.input_quant = input_quant
+        # Extract properties from weight_quant
+        self.num_bits = weight_quant.num_bits
+        self.packed_factor = 32 // weight_quant.num_bits
+        self.strategy = weight_quant.strategy
         # channelwise is not supported by this kernel
-        assert config.strategy == "group"
-        self.group_size = config.group_size
+        assert weight_quant.strategy == "group"
+        self.group_size = weight_quant.group_size
         # grouped actorder isn't supported by this kernel
-        assert config.actorder != "group"
-        assert config.symmetric, "Only symmetric quantization is supported for MoE"
-        if not (
-            self.quant_config.quant_format == CompressionFormat.pack_quantized.value
-            and self.num_bits in WNA16_SUPPORTED_BITS
-        ):
-            raise ValueError(
-                "For Fused MoE layers, only ",
-                f"{CompressionFormat.pack_quantized.value} ",
-                "is supported for the following bits: ",
-                f"{WNA16_SUPPORTED_BITS}",
-            )
+        assert weight_quant.actorder != "group"
+        assert weight_quant.symmetric, (
+            "Only symmetric quantization is supported for MoE"
+        )

     def create_weights(
         self,
@@ -2065,28 +2069,33 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self,
-        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
+        weight_quant: QuantizationArgs,
+        input_quant: QuantizationArgs,
         moe: FusedMoEConfig,
         layer_name: str | None = None,
     ):
         super().__init__(moe)
         self.has_bias = self.moe.has_bias
-        self.quant_config = quant_config
+        self.weight_quant = weight_quant
+        self.input_quant = input_quant
         # Validate scheme: weights=W4 (channel or group),
         # activations=dynamic TOKEN (A8)
-        wq = self.quant_config.target_scheme_map["Linear"].get("weights")
-        aq = self.quant_config.target_scheme_map["Linear"].get("input_activations")
         # Must be dynamic per-token activations
-        if aq.strategy != QuantizationStrategy.TOKEN or not aq.dynamic:
+        if (
+            input_quant.strategy != QuantizationStrategy.TOKEN
+            or not input_quant.dynamic
+        ):
             raise ValueError(
                 "W4A8-int MoE needs dynamic per-token activation quantization."
             )
         # Weight can be channel-wise (group_size=None) or group-wise
-        self.group_size = wq.group_size if (wq.group_size is not None) else -1
-        if wq.num_bits != 4:
+        self.group_size = (
+            weight_quant.group_size if (weight_quant.group_size is not None) else -1
+        )
+        if weight_quant.num_bits != 4:
             raise ValueError("This method only supports 4-bit weights (num_bits=4).")
         # CPU only
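
Reviewer note: the W4A8-int constructor now validates the passed-in args directly: activations must be dynamic per-token, weights must be 4-bit, and group_size=None (channelwise) maps to -1. A sketch of those checks with hypothetical QuantizationArgs values (field names follow their use elsewhere in this diff; treat the constructor kwargs as assumptions):

from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy

wq = QuantizationArgs(num_bits=4, strategy="group", group_size=128, symmetric=True)
aq = QuantizationArgs(num_bits=8, strategy="token", dynamic=True)

assert aq.strategy == QuantizationStrategy.TOKEN and aq.dynamic
assert wq.num_bits == 4
group_size = wq.group_size if wq.group_size is not None else -1  # -1 means channelwise
print(group_size)  # 128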