diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index d14b524b793a..375645fde747 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -632,6 +632,7 @@ steps:
   # we can only upgrade after this is resolved
   # TODO(jerryzh168): resolve the above comment
   - uv pip install --system torchao==0.13.0 --index-url https://download.pytorch.org/whl/cu129
+  - uv pip install --system conch-triton-kernels
   - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/ --ignore quantization/test_blackwell_moe.py
 
 - label: LM Eval Small Models # 53min
diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py
index 31b65189b5ec..412b21328a32 100644
--- a/tests/quantization/test_compressed_tensors.py
+++ b/tests/quantization/test_compressed_tensors.py
@@ -10,6 +10,7 @@ import torch
 from compressed_tensors.quantization import QuantizationType
 
 from tests.models.utils import check_logprobs_close
+from vllm.model_executor.layers.fused_moe import UnquantizedFusedMoEMethod
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
     CompressedTensors24,
     CompressedTensorsLinearMethod,
@@ -767,3 +768,50 @@ def test_compressed_tensors_fp8_block_enabled(vllm_runner):
 
         output = llm.generate_greedy("Hello my name is", max_tokens=4)
         assert output
+
+
+@pytest.mark.skipif(
+    not current_platform.is_cuda(),
+    reason="This test is not for non-CUDA platforms",
+)
+def test_compressed_tensors_moe_ignore_with_model(vllm_runner):
+    """
+    Integration test for MoE layer ignore functionality with a real model.
+
+    This test verifies that when loading a compressed-tensors quantized
+    MoE model where some MoE layers are in the ignore list, those layers
+    use UnquantizedFusedMoEMethod while non-ignored layers use the
+    quantized method.
+
+    Expected model structure:
+    - Compressed-tensors quantized MoE model (e.g., Qwen3-MoE-based)
+    - Config with ignore list containing specific MoE layers
+    - Multiple MoE layers where some are quantized and some are not
+    """
+
+    # model_path = "nm-testing/tinysmokeqwen3moe-W4A16-first-only"  # CT 12.3
+    model_path = "nm-testing/tinysmokeqwen3moe-W4A16-first-only-CTstable"  # CT 12.2
+
+    with vllm_runner(model_path, enforce_eager=True) as llm:
+
+        def check_model(model):
+            from vllm.model_executor.layers.fused_moe import FusedMoE
+            from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (  # noqa: E501
+                CompressedTensorsMoEMethod,
+            )
+
+            # Check layer 0 MoE (should be quantized)
+            layer_quantized = model.model.layers[0].mlp.experts
+            assert isinstance(layer_quantized, FusedMoE)
+            assert isinstance(layer_quantized.quant_method, CompressedTensorsMoEMethod)
+
+            # Check layer 3 MoE (should be ignored and therefore unquantized)
+            layer_unquantized = model.model.layers[3].mlp.experts
+            assert isinstance(layer_unquantized, FusedMoE)
+            assert isinstance(layer_unquantized.quant_method, UnquantizedFusedMoEMethod)
+
+        llm.apply_model(check_model)
+
+        # Verify the model can generate output
+        output = llm.generate_greedy("Hello, my name is", max_tokens=4)
+        assert output
diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py
index 53d98d0650b4..669abcb3d6ff 100644
--- a/vllm/model_executor/layers/fused_moe/__init__.py
+++ b/vllm/model_executor/layers/fused_moe/__init__.py
@@ -18,6 +18,9 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEPrepareAndFinalize,
 )
 from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
+from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
+    UnquantizedFusedMoEMethod,
+)
 from vllm.model_executor.layers.fused_moe.utils import activation_without_mul
 from vllm.triton_utils import HAS_TRITON
 
@@ -41,6 +44,7 @@ __all__ = [
     "FusedMoE",
     "FusedMoEConfig",
     "FusedMoEMethodBase",
+    "UnquantizedFusedMoEMethod",
     "FusedMoeWeightScaleSupported",
     "FusedMoEPermuteExpertsUnpermute",
     "FusedMoEActivationFormat",
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index 2800f90ce0b6..7f61746a4e45 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -158,9 +158,23 @@ class CompressedTensorsConfig(QuantizationConfig):
         if isinstance(layer, Attention):
             return CompressedTensorsKVCacheMethod(self)
         if isinstance(layer, FusedMoE):
-            return CompressedTensorsMoEMethod.get_moe_method(self, layer)
+            return CompressedTensorsMoEMethod.get_moe_method(self, layer, prefix)
         return None
 
+    def _add_fused_moe_to_target_scheme_map(self):
+        """
+        Helper function to update target_scheme_map:
+        since linear layers get fused into FusedMoE,
+        targeting 'Linear' needs to also match
+        FusedMoE modules.
+        """
+        if (
+            "Linear" not in self.target_scheme_map
+            or "FusedMoE" in self.target_scheme_map
+        ):
+            return
+        self.target_scheme_map["FusedMoE"] = self.target_scheme_map["Linear"]
+
     @classmethod
     def from_config(cls, config: dict[str, Any]) -> "CompressedTensorsConfig":
         ignore: list[str] = cast(list[str], config.get("ignore", []))
@@ -655,25 +669,13 @@ class CompressedTensorsConfig(QuantizationConfig):
         to select the CompressedTensorsScheme used for inference.
         """
 
-        # Find the "target" in the compressed-tensors config
-        # that our layer conforms to.
-        # TODO (@kylesayrs): support ignore module names with ct matching utils
-        if should_ignore_layer(
-            layer_name, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
-        ):
-            return None
+        # Use get_scheme_dict to extract the QuantizationArgs for this layer
+        scheme_dict = self.get_scheme_dict(layer, layer_name)
 
-        # Will be empty for models with only sparsity
-        weight_quant = input_quant = None
-        if self.target_scheme_map:
-            matched_target = find_matched_target(
-                layer_name=layer_name,
-                module=layer,
-                targets=self.target_scheme_map.keys(),
-                fused_mapping=self.packed_modules_mapping,
-            )
-
-            scheme_dict = self.target_scheme_map[matched_target]
+        weight_quant = None
+        input_quant = None
+        format = None
+        if scheme_dict:
             weight_quant = scheme_dict.get("weights")
             input_quant = scheme_dict.get("input_activations")
             format = scheme_dict.get("format")
@@ -732,6 +734,38 @@ class CompressedTensorsConfig(QuantizationConfig):
         logger.debug("Using scheme: %s for %s", scheme.__class__.__name__, layer_name)
         return scheme
 
+    def get_scheme_dict(
+        self, layer: torch.nn.Module, layer_name: str | None = None
+    ) -> dict[str, QuantizationArgs | str | None] | None:
+        """
+        Extract the quantization scheme dict (QuantizationArgs) for a given layer.
+
+        Returns:
+            dict with {
+                "weights": QuantizationArgs,
+                "input_activations": QuantizationArgs | None,
+                "format": str | None
+            } | None
+        """
+        # TODO (@kylesayrs): support ignore module names with ct matching utils
+        if should_ignore_layer(
+            layer_name, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
+        ):
+            return None
+
+        # Will be empty for models with only sparsity
+        if self.target_scheme_map:
+            matched_target = find_matched_target(
+                layer_name=layer_name,
+                module=layer,
+                targets=self.target_scheme_map.keys(),
+                fused_mapping=self.packed_modules_mapping,
+            )
+
+            return self.target_scheme_map[matched_target]
+
+        return None
+
     def get_cache_scale(self, name: str) -> str | None:
         """
         Check whether the param name matches the format for k/v cache scales
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 71d7de97d4a1..c7dfd1787cc8 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -23,6 +23,7 @@ from vllm.model_executor.layers.fused_moe import (
     FusedMoEMethodBase,
     FusedMoEPermuteExpertsUnpermute,
     FusedMoeWeightScaleSupported,
+    UnquantizedFusedMoEMethod,
 )
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEQuantConfig,
@@ -45,9 +46,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compress
     WNA16_SUPPORTED_BITS,
     WNA16_SUPPORTED_TYPES_MAP,
 )
-from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    find_matched_target,
-)
 from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
     build_flashinfer_fp4_cutlass_moe_prepare_finalize,
@@ -113,39 +111,35 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
     def get_moe_method(
         quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
         layer: torch.nn.Module,
+        prefix: str,
     ) -> "CompressedTensorsMoEMethod":
+        # FusedMoE is built by fusing multiple Linear layers, so make sure a
+        # quantization config that targets 'Linear' can also target FusedMoE
+        quant_config._add_fused_moe_to_target_scheme_map()
+        unfused_names = [
+            prefix + proj_name
+            for proj_name in [".0.gate_proj", ".0.up_proj", ".0.down_proj"]
+        ]
+        # TODO: refactor this to use expert_mapping and check all layer numbers
+        all_scheme_dicts = [
+            quant_config.get_scheme_dict(layer, name) for name in unfused_names
+        ]
+        scheme_dict = all_scheme_dicts.pop()
+
+        # multiple schemes found
+        if not all([cur_dict == scheme_dict for cur_dict in all_scheme_dicts]):
+            raise ValueError(
+                "All MoE projections need to have the same "
+                "quantization scheme, but multiple were found"
+            )
+
+        if scheme_dict is None:  # ignored layer
+            return UnquantizedFusedMoEMethod(layer.moe_config)
+
         # TODO: @dsikka: refactor this to use schemes as other kernels
         # are supported + check if the layer is being ignored.
-        # Check if a using "Linear" to select schemes
-        if "Linear" in quant_config.target_scheme_map:
-            matched_target = "Linear"
-        else:
-            # May have instead defined the linear layers in the fused model
-
-            fused_layers = ["re:.*down_proj.*", "re:.*gate_proj.*", "re:.*up_proj.*"]
-            current_scheme = None
-            for fused_layer in fused_layers:
-                # Check if one of the fused layers are defined in quant_config
-                matched_target = find_matched_target(
-                    layer_name=fused_layer,
-                    module=layer,
-                    targets=quant_config.target_scheme_map.keys(),
-                    fused_mapping=quant_config.packed_modules_mapping,
-                )
-
-                # Only valid if down_proj, gate_proj, and up_proj
-                # are mapped to the same quant scheme in the quant_config
-                if current_scheme is None:
-                    current_scheme = quant_config.target_scheme_map.get(matched_target)
-                else:
-                    assert current_scheme == quant_config.target_scheme_map.get(
-                        matched_target
-                    )
-
-        weight_quant = quant_config.target_scheme_map[matched_target].get("weights")
-        input_quant = quant_config.target_scheme_map[matched_target].get(
-            "input_activations"
-        )
+        weight_quant = scheme_dict.get("weights")
+        input_quant = scheme_dict.get("input_activations")
 
         if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
             # group_size=None means channelwise
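
A minimal sketch (not part of the diff) of how the new ignore handling resolves a FusedMoE layer; the prefix value below is hypothetical and only illustrates the names that get_moe_method probes via get_scheme_dict.

# Illustrative sketch only; the prefix is a hypothetical module path.
prefix = "model.layers.3.mlp.experts"

# get_moe_method builds the first expert's unfused projection names ...
unfused_names = [
    prefix + proj_name
    for proj_name in [".0.gate_proj", ".0.up_proj", ".0.down_proj"]
]
# -> ["model.layers.3.mlp.experts.0.gate_proj",
#     "model.layers.3.mlp.experts.0.up_proj",
#     "model.layers.3.mlp.experts.0.down_proj"]

# ... and looks each of them up with get_scheme_dict(). If all three return
# None (the names match the config's ignore list), the layer falls back to
# UnquantizedFusedMoEMethod; if they all map to the same scheme dict, the
# matching CompressedTensorsMoEMethod implementation is chosen as before.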