[Bugfix] Make compressed-tensors MoEs respect ignored layers (#28878)

Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
HDCharles 2025-11-26 21:35:13 -05:00 committed by GitHub
parent ba1fcd84a7
commit df01eda4dc
5 changed files with 133 additions and 52 deletions


@@ -632,6 +632,7 @@ steps:
# we can only upgrade after this is resolved
# TODO(jerryzh168): resolve the above comment
- uv pip install --system torchao==0.13.0 --index-url https://download.pytorch.org/whl/cu129
- uv pip install --system conch-triton-kernels
- VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/ --ignore quantization/test_blackwell_moe.py
- label: LM Eval Small Models # 53min


@@ -10,6 +10,7 @@ import torch
from compressed_tensors.quantization import QuantizationType
from tests.models.utils import check_logprobs_close
from vllm.model_executor.layers.fused_moe import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensors24,
CompressedTensorsLinearMethod,
@@ -767,3 +768,50 @@ def test_compressed_tensors_fp8_block_enabled(vllm_runner):
output = llm.generate_greedy("Hello my name is", max_tokens=4)
assert output
@pytest.mark.skipif(
not current_platform.is_cuda(),
reason="This test is not for non-CUDA platforms",
)
def test_compressed_tensors_moe_ignore_with_model(vllm_runner):
"""
Integration test for MoE layer ignore functionality with a real model.
This test verifies that when loading a compressed-tensors quantized
MoE model where some MoE layers are in the ignore list, those layers
use UnquantizedFusedMoEMethod while non-ignored layers use the
quantized method.
Expected model structure:
- Compressed-tensors quantized MoE model (e.g., Qwen3-MoE-based)
- Config with an ignore list containing specific MoE layers
- Multiple MoE layers where some are quantized and some are not
"""
# model_path = "nm-testing/tinysmokeqwen3moe-W4A16-first-only" # CT 12.3
model_path = "nm-testing/tinysmokeqwen3moe-W4A16-first-only-CTstable" # CT 12.2
with vllm_runner(model_path, enforce_eager=True) as llm:
def check_model(model):
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501
CompressedTensorsMoEMethod,
)
# Check layer 0 MoE (should be quantized)
layer_quantized = model.model.layers[0].mlp.experts
assert isinstance(layer_quantized, FusedMoE)
assert isinstance(layer_quantized.quant_method, CompressedTensorsMoEMethod)
# Check layer 3 MoE (should be unquantized + ignored)
layer_unquantized = model.model.layers[3].mlp.experts
assert isinstance(layer_unquantized, FusedMoE)
assert isinstance(layer_unquantized.quant_method, UnquantizedFusedMoEMethod)
llm.apply_model(check_model)
# Verify the model can generate output
output = llm.generate_greedy("Hello, my name is", max_tokens=4)
assert output
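
For context, the ignore behavior is driven by the "ignore" field of the model's compressed-tensors quantization config. Below is a minimal, hypothetical config fragment written as a Python dict; the group contents and the ignore pattern are illustrative and do not reproduce the test model's actual config.

# Hypothetical compressed-tensors quantization_config fragment (illustrative only).
quantization_config = {
    "quant_method": "compressed-tensors",
    "config_groups": {
        "group_0": {
            # Targeting "Linear" now also covers the fused MoE expert projections.
            "targets": ["Linear"],
            "weights": {"num_bits": 4, "type": "int", "group_size": 128},
        }
    },
    # Layers matching these patterns were skipped during quantization, so vLLM
    # must serve them with UnquantizedFusedMoEMethod rather than a quantized kernel.
    "ignore": ["re:.*layers\\.3\\.mlp\\.experts.*"],
}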


@@ -18,6 +18,9 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
UnquantizedFusedMoEMethod,
)
from vllm.model_executor.layers.fused_moe.utils import activation_without_mul
from vllm.triton_utils import HAS_TRITON
@@ -41,6 +44,7 @@ __all__ = [
"FusedMoE",
"FusedMoEConfig",
"FusedMoEMethodBase",
"UnquantizedFusedMoEMethod",
"FusedMoeWeightScaleSupported",
"FusedMoEPermuteExpertsUnpermute",
"FusedMoEActivationFormat",
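
Since UnquantizedFusedMoEMethod is re-exported here, callers (such as the new test above) can import it from the package root instead of the submodule. A minimal sketch showing that both import paths resolve to the same class:

# Sketch: both import paths refer to the same class after this change.
from vllm.model_executor.layers.fused_moe import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
    UnquantizedFusedMoEMethod as _UnquantizedFusedMoEMethod,
)

assert UnquantizedFusedMoEMethod is _UnquantizedFusedMoEMethod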


@@ -158,9 +158,23 @@ class CompressedTensorsConfig(QuantizationConfig):
if isinstance(layer, Attention):
return CompressedTensorsKVCacheMethod(self)
if isinstance(layer, FusedMoE):
return CompressedTensorsMoEMethod.get_moe_method(self, layer)
return CompressedTensorsMoEMethod.get_moe_method(self, layer, prefix)
return None
def _add_fused_moe_to_target_scheme_map(self):
"""
Helper function to update target_scheme_map:
since linear layers get fused into FusedMoE,
targeting 'Linear' also needs to match
FusedMoE modules.
"""
if (
"Linear" not in self.target_scheme_map
or "FusedMoE" in self.target_scheme_map
):
return
self.target_scheme_map["FusedMoE"] = self.target_scheme_map["Linear"]
@classmethod
def from_config(cls, config: dict[str, Any]) -> "CompressedTensorsConfig":
ignore: list[str] = cast(list[str], config.get("ignore", []))
@@ -655,25 +669,13 @@ class CompressedTensorsConfig(QuantizationConfig):
to select the CompressedTensorsScheme used for inference.
"""
# Find the "target" in the compressed-tensors config
# that our layer conforms to.
# TODO (@kylesayrs): support ignore module names with ct matching utils
if should_ignore_layer(
layer_name, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
):
return None
# Use the new get_scheme_dict method to extract the QuantizationArgs
scheme_dict = self.get_scheme_dict(layer, layer_name)
# Will be empty for models with only sparsity
weight_quant = input_quant = None
if self.target_scheme_map:
matched_target = find_matched_target(
layer_name=layer_name,
module=layer,
targets=self.target_scheme_map.keys(),
fused_mapping=self.packed_modules_mapping,
)
scheme_dict = self.target_scheme_map[matched_target]
weight_quant = None
input_quant = None
format = None
if scheme_dict:
weight_quant = scheme_dict.get("weights")
input_quant = scheme_dict.get("input_activations")
format = scheme_dict.get("format")
@@ -732,6 +734,38 @@ class CompressedTensorsConfig(QuantizationConfig):
logger.debug("Using scheme: %s for %s", scheme.__class__.__name__, layer_name)
return scheme
def get_scheme_dict(
self, layer: torch.nn.Module, layer_name: str | None = None
) -> dict[str, QuantizationArgs | str | None] | None:
"""
Extract the QuantizationArgs for a given layer.
Returns:
dict with {
"weights": QuantizationArgs,
"input_activations": QuantizationArgs | None,
"format": str | None
} | None
"""
# TODO (@kylesayrs): support ignore module names with ct matching utils
if should_ignore_layer(
layer_name, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
):
return None
# Will be empty for models with only sparsity
if self.target_scheme_map:
matched_target = find_matched_target(
layer_name=layer_name,
module=layer,
targets=self.target_scheme_map.keys(),
fused_mapping=self.packed_modules_mapping,
)
return self.target_scheme_map[matched_target]
return None
def get_cache_scale(self, name: str) -> str | None:
"""
Check whether the param name matches the format for k/v cache scales
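
A minimal sketch of how a caller consumes the new get_scheme_dict contract; the helper name resolve_moe_scheme and its arguments are hypothetical, with quant_config assumed to be a CompressedTensorsConfig instance.

def resolve_moe_scheme(quant_config, layer, layer_name):
    # Mirrors how get_scheme() and get_moe_method() use the helper: a None
    # result means the layer is ignored (or the config carries only sparsity).
    scheme_dict = quant_config.get_scheme_dict(layer, layer_name)
    if scheme_dict is None:
        return None  # caller should fall back to an unquantized method
    weight_quant = scheme_dict.get("weights")            # QuantizationArgs
    input_quant = scheme_dict.get("input_activations")   # QuantizationArgs | None
    fmt = scheme_dict.get("format")                      # str | None
    return weight_quant, input_quant, fmt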


@@ -23,6 +23,7 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute,
FusedMoeWeightScaleSupported,
UnquantizedFusedMoEMethod,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
@@ -45,9 +46,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compress
WNA16_SUPPORTED_BITS,
WNA16_SUPPORTED_TYPES_MAP,
)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
find_matched_target,
)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_prepare_finalize,
@@ -113,39 +111,35 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
def get_moe_method(
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
layer: torch.nn.Module,
prefix: str,
) -> "CompressedTensorsMoEMethod":
# FusedMoE is made by combining multiple Linears, so we need to
# make sure a quantization config that targets Linear can also target it
quant_config._add_fused_moe_to_target_scheme_map()
unfused_names = [
prefix + proj_name
for proj_name in [".0.gate_proj", ".0.up_proj", ".0.down_proj"]
]
# TODO: refactor this to use expert_mapping and check all layer numbers
all_scheme_dicts = [
quant_config.get_scheme_dict(layer, name) for name in unfused_names
]
scheme_dict = all_scheme_dicts.pop()
# raise if multiple different schemes were found
if not all([cur_dict == scheme_dict for cur_dict in all_scheme_dicts]):
raise ValueError(
"All MoE projections need to have same "
"quantization scheme but found multiple"
)
if scheme_dict is None: # ignored layer
return UnquantizedFusedMoEMethod(layer.moe_config)
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
# Check if using "Linear" to select schemes
if "Linear" in quant_config.target_scheme_map:
matched_target = "Linear"
else:
# May have instead defined the linear layers in the fused model
fused_layers = ["re:.*down_proj.*", "re:.*gate_proj.*", "re:.*up_proj.*"]
current_scheme = None
for fused_layer in fused_layers:
# Check if one of the fused layers are defined in quant_config
matched_target = find_matched_target(
layer_name=fused_layer,
module=layer,
targets=quant_config.target_scheme_map.keys(),
fused_mapping=quant_config.packed_modules_mapping,
)
# Only valid if down_proj, gate_proj, and up_proj
# are mapped to the same quant scheme in the quant_config
if current_scheme is None:
current_scheme = quant_config.target_scheme_map.get(matched_target)
else:
assert current_scheme == quant_config.target_scheme_map.get(
matched_target
)
weight_quant = quant_config.target_scheme_map[matched_target].get("weights")
input_quant = quant_config.target_scheme_map[matched_target].get(
"input_activations"
)
weight_quant = scheme_dict.get("weights")
input_quant = scheme_dict.get("input_activations")
if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
# group_size=None means channelwise
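
For concreteness, a short sketch of the unfused names probed above; the prefix value is hypothetical and depends on the model architecture.

# Hypothetical prefix for the experts module of decoder layer 0.
prefix = "model.layers.0.mlp.experts"
unfused_names = [
    prefix + proj_name
    for proj_name in [".0.gate_proj", ".0.up_proj", ".0.down_proj"]
]
print(unfused_names)
# ['model.layers.0.mlp.experts.0.gate_proj',
#  'model.layers.0.mlp.experts.0.up_proj',
#  'model.layers.0.mlp.experts.0.down_proj']
# Only expert index 0 is probed; the TODO above notes that checking every
# expert via expert_mapping is left for follow-up work.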