[Bugfix] Make compressed-tensors MoEs respect ignored layers (#28878)
Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
parent ba1fcd84a7
commit df01eda4dc
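Context for the fix: compressed-tensors checkpoints carry an "ignore" list of layers that must stay unquantized, and before this change a FusedMoE module built from ignored expert projections could still be assigned a quantized MoE method. Below is a minimal sketch of such a config; the group names, field values, and layer patterns are illustrative assumptions, not taken from this commit or from the test checkpoint.

# Hypothetical compressed-tensors quantization_config excerpt (illustrative names):
# "Linear" targets are quantized W4A16, but the expert projections of layer 3 are
# listed under "ignore", so the FusedMoE module built from them should end up with
# UnquantizedFusedMoEMethod after this fix.
quantization_config = {
    "config_groups": {
        "group_0": {
            "targets": ["Linear"],
            "weights": {"num_bits": 4, "type": "int", "symmetric": True, "group_size": 128},
            "input_activations": None,
        }
    },
    "ignore": [
        "lm_head",
        "re:.*layers\\.3\\.mlp\\.experts.*",
    ],
}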
@@ -632,6 +632,7 @@ steps:
# we can only upgrade after this is resolved
# TODO(jerryzh168): resolve the above comment
- uv pip install --system torchao==0.13.0 --index-url https://download.pytorch.org/whl/cu129
- uv pip install --system conch-triton-kernels
- VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/ --ignore quantization/test_blackwell_moe.py

- label: LM Eval Small Models # 53min
@@ -10,6 +10,7 @@ import torch

from compressed_tensors.quantization import QuantizationType

from tests.models.utils import check_logprobs_close
from vllm.model_executor.layers.fused_moe import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
    CompressedTensors24,
    CompressedTensorsLinearMethod,
@@ -767,3 +768,50 @@ def test_compressed_tensors_fp8_block_enabled(vllm_runner):

        output = llm.generate_greedy("Hello my name is", max_tokens=4)
        assert output


@pytest.mark.skipif(
    not current_platform.is_cuda(),
    reason="This test is not for non-CUDA platforms",
)
def test_compressed_tensors_moe_ignore_with_model(vllm_runner):
    """
    Integration test for MoE layer ignore functionality with a real model.

    This test verifies that when loading a compressed-tensors quantized
    MoE model where some MoE layers are in the ignore list, those layers
    use UnquantizedFusedMoEMethod while non-ignored layers use the
    quantized method.

    Expected model structure:
    - Compressed-tensors quantized MoE model (e.g., Mixtral-based)
    - Config with ignore list containing specific MoE layers
    - Multiple MoE layers where some are quantized and some are not
    """

    # model_path = "nm-testing/tinysmokeqwen3moe-W4A16-first-only"  # CT 12.3
    model_path = "nm-testing/tinysmokeqwen3moe-W4A16-first-only-CTstable"  # CT 12.2

    with vllm_runner(model_path, enforce_eager=True) as llm:

        def check_model(model):
            from vllm.model_executor.layers.fused_moe import FusedMoE
            from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (  # noqa: E501
                CompressedTensorsMoEMethod,
            )

            # Check layer 0 MoE (should be quantized)
            layer_quantized = model.model.layers[0].mlp.experts
            assert isinstance(layer_quantized, FusedMoE)
            assert isinstance(layer_quantized.quant_method, CompressedTensorsMoEMethod)

            # Check layer 3 MoE (should be unquantized because it is ignored)
            layer_unquantized = model.model.layers[3].mlp.experts
            assert isinstance(layer_unquantized, FusedMoE)
            assert isinstance(layer_unquantized.quant_method, UnquantizedFusedMoEMethod)

        llm.apply_model(check_model)

        # Verify the model can generate output
        output = llm.generate_greedy("Hello, my name is", max_tokens=4)
        assert output
@@ -18,6 +18,9 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
    UnquantizedFusedMoEMethod,
)
from vllm.model_executor.layers.fused_moe.utils import activation_without_mul
from vllm.triton_utils import HAS_TRITON

@@ -41,6 +44,7 @@ __all__ = [
    "FusedMoE",
    "FusedMoEConfig",
    "FusedMoEMethodBase",
    "UnquantizedFusedMoEMethod",
    "FusedMoeWeightScaleSupported",
    "FusedMoEPermuteExpertsUnpermute",
    "FusedMoEActivationFormat",
@@ -158,9 +158,23 @@ class CompressedTensorsConfig(QuantizationConfig):
        if isinstance(layer, Attention):
            return CompressedTensorsKVCacheMethod(self)
        if isinstance(layer, FusedMoE):
            return CompressedTensorsMoEMethod.get_moe_method(self, layer)
            return CompressedTensorsMoEMethod.get_moe_method(self, layer, prefix)
        return None

    def _add_fused_moe_to_target_scheme_map(self):
        """
        Helper function to update target_scheme_map:
        since linear layers get fused into FusedMoE,
        targeting 'Linear' needs to also match
        FusedMoE modules.
        """
        if (
            "Linear" not in self.target_scheme_map
            or "FusedMoE" in self.target_scheme_map
        ):
            return
        self.target_scheme_map["FusedMoE"] = self.target_scheme_map["Linear"]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "CompressedTensorsConfig":
        ignore: list[str] = cast(list[str], config.get("ignore", []))
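For context on the helper added in the hunk above: because the expert Linear projections are fused into a single FusedMoE module, a scheme that targets "Linear" also has to be visible under a "FusedMoE" key. A minimal standalone sketch of that effect follows; the scheme value is a made-up placeholder rather than a real QuantizationArgs object.

# Standalone sketch (hypothetical scheme value, not the vLLM API): mirror the
# "Linear" entry under "FusedMoE" so fused expert modules resolve to the same
# scheme as their unfused projections.
def add_fused_moe_to_target_scheme_map(target_scheme_map: dict) -> None:
    if "Linear" not in target_scheme_map or "FusedMoE" in target_scheme_map:
        return
    target_scheme_map["FusedMoE"] = target_scheme_map["Linear"]

scheme_map = {"Linear": {"weights": "int4-group128", "input_activations": None}}
add_fused_moe_to_target_scheme_map(scheme_map)
assert scheme_map["FusedMoE"] is scheme_map["Linear"]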
@@ -655,25 +669,13 @@ class CompressedTensorsConfig(QuantizationConfig):
        to select the CompressedTensorsScheme used for inference.
        """

        # Find the "target" in the compressed-tensors config
        # that our layer conforms to.
        # TODO (@kylesayrs): support ignore module names with ct matching utils
        if should_ignore_layer(
            layer_name, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
        ):
            return None
        # Use the new get_scheme_dict method to extract QuantizationArgs
        scheme_dict = self.get_scheme_dict(layer, layer_name)

        # Will be empty for models with only sparsity
        weight_quant = input_quant = None
        if self.target_scheme_map:
            matched_target = find_matched_target(
                layer_name=layer_name,
                module=layer,
                targets=self.target_scheme_map.keys(),
                fused_mapping=self.packed_modules_mapping,
            )

            scheme_dict = self.target_scheme_map[matched_target]
        weight_quant = None
        input_quant = None
        format = None
        if scheme_dict:
            weight_quant = scheme_dict.get("weights")
            input_quant = scheme_dict.get("input_activations")
            format = scheme_dict.get("format")
@@ -732,6 +734,38 @@ class CompressedTensorsConfig(QuantizationConfig):
        logger.debug("Using scheme: %s for %s", scheme.__class__.__name__, layer_name)
        return scheme

    def get_scheme_dict(
        self, layer: torch.nn.Module, layer_name: str | None = None
    ) -> dict[str, QuantizationArgs | str | None] | None:
        """
        Extract the QuantizationArgs for a given layer.

        Returns:
            dict with {
                "weights": QuantizationArgs,
                "input_activations": QuantizationArgs | None,
                "format": str | None
            } | None
        """
        # TODO (@kylesayrs): support ignore module names with ct matching utils
        if should_ignore_layer(
            layer_name, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
        ):
            return None

        # Will be empty for models with only sparsity
        if self.target_scheme_map:
            matched_target = find_matched_target(
                layer_name=layer_name,
                module=layer,
                targets=self.target_scheme_map.keys(),
                fused_mapping=self.packed_modules_mapping,
            )

            return self.target_scheme_map[matched_target]

        return None

    def get_cache_scale(self, name: str) -> str | None:
        """
        Check whether the param name matches the format for k/v cache scales
@@ -23,6 +23,7 @@ from vllm.model_executor.layers.fused_moe import (
    FusedMoEMethodBase,
    FusedMoEPermuteExpertsUnpermute,
    FusedMoeWeightScaleSupported,
    UnquantizedFusedMoEMethod,
)
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEQuantConfig,

@@ -45,9 +46,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compress
    WNA16_SUPPORTED_BITS,
    WNA16_SUPPORTED_TYPES_MAP,
)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    find_matched_target,
)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
    build_flashinfer_fp4_cutlass_moe_prepare_finalize,
@@ -113,39 +111,35 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
    def get_moe_method(
        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
        layer: torch.nn.Module,
        prefix: str,
    ) -> "CompressedTensorsMoEMethod":
        # FusedMoE was made by combining multiple Linears so need to
        # make sure quantization config for Linear can target it
        quant_config._add_fused_moe_to_target_scheme_map()
        unfused_names = [
            prefix + proj_name
            for proj_name in [".0.gate_proj", ".0.up_proj", ".0.down_proj"]
        ]
        # TODO: refactor this to use expert_mapping and check all layer numbers
        all_scheme_dicts = [
            quant_config.get_scheme_dict(layer, name) for name in unfused_names
        ]
        scheme_dict = all_scheme_dicts.pop()

        # multiple schemes found
        if not all([cur_dict == scheme_dict for cur_dict in all_scheme_dicts]):
            raise ValueError(
                "All MoE projections need to have same "
                "quantization scheme but found multiple"
            )

        if scheme_dict is None:  # ignored layer
            return UnquantizedFusedMoEMethod(layer.moe_config)

        # TODO: @dsikka: refactor this to use schemes as other kernels
        # are supported + check if the layer is being ignored.
        # Check if a using "Linear" to select schemes
        if "Linear" in quant_config.target_scheme_map:
            matched_target = "Linear"
        else:
            # May have instead defined the linear layers in the fused model

            fused_layers = ["re:.*down_proj.*", "re:.*gate_proj.*", "re:.*up_proj.*"]
            current_scheme = None
            for fused_layer in fused_layers:
                # Check if one of the fused layers are defined in quant_config
                matched_target = find_matched_target(
                    layer_name=fused_layer,
                    module=layer,
                    targets=quant_config.target_scheme_map.keys(),
                    fused_mapping=quant_config.packed_modules_mapping,
                )

                # Only valid if down_proj, gate_proj, and up_proj
                # are mapped to the same quant scheme in the quant_config
                if current_scheme is None:
                    current_scheme = quant_config.target_scheme_map.get(matched_target)
                else:
                    assert current_scheme == quant_config.target_scheme_map.get(
                        matched_target
                    )

        weight_quant = quant_config.target_scheme_map[matched_target].get("weights")
        input_quant = quant_config.target_scheme_map[matched_target].get(
            "input_activations"
        )
        weight_quant = scheme_dict.get("weights")
        input_quant = scheme_dict.get("input_activations")

        if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
            # group_size=None means channelwise
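To summarize the new selection logic in the hunk above: get_moe_method derives the unfused expert projection names from the FusedMoE prefix, asks the config for each projection's scheme dict, and falls back to the unquantized MoE method only when every projection resolves to None (i.e. is ignored). A self-contained sketch of that decision follows; the lookup function and layer names are hypothetical stand-ins for quant_config.get_scheme_dict, and only the shape of the logic mirrors the diff.

# Illustrative sketch of the decision flow (hypothetical names, not the vLLM API).
def resolve_moe_scheme(prefix: str, lookup) -> dict | None:
    unfused_names = [
        prefix + proj_name
        for proj_name in (".0.gate_proj", ".0.up_proj", ".0.down_proj")
    ]
    scheme_dicts = [lookup(name) for name in unfused_names]
    if any(d != scheme_dicts[0] for d in scheme_dicts):
        raise ValueError("All MoE projections need to have the same quantization scheme")
    # None means every projection was ignored -> use UnquantizedFusedMoEMethod
    return scheme_dicts[0]

def lookup(name: str) -> dict | None:
    # pretend layer 3 is in the checkpoint's ignore list and layer 0 is W4A16
    return None if ".layers.3." in name else {"weights": "int4-group128"}

assert resolve_moe_scheme("model.layers.3.mlp.experts", lookup) is None
assert resolve_moe_scheme("model.layers.0.mlp.experts", lookup) == {"weights": "int4-group128"}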