Enable modelopt gemma3 nvfp4/fp8, make workflow more robust (#22771)

Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Zhiyu 2025-09-19 15:40:33 -07:00 committed by yewentao256
parent e54a476058
commit 6e94161f94
7 changed files with 82 additions and 22 deletions
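
For reference, a minimal usage sketch of what this change enables: serving a Gemma 3 checkpoint quantized with NVIDIA TensorRT Model Optimizer (ModelOpt) in NVFP4 or FP8. The model ID below is a placeholder, not a checkpoint referenced by this commit; the quantization method is normally auto-detected from the checkpoint's quantization config (FP8 maps to the "modelopt" method, NVFP4 to "modelopt_fp4").

from vllm import LLM, SamplingParams

# Placeholder model ID; substitute a real ModelOpt-quantized Gemma 3 checkpoint.
llm = LLM(model="nvidia/gemma-3-27b-it-FP8")
out = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(out[0].outputs[0].text)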

View File

@@ -11,7 +11,8 @@ import pytest
 import torch

 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
-from vllm.config import VllmConfig, current_platform, set_current_vllm_config
+from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.platforms import current_platform
 from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
 from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe

View File

@@ -31,8 +31,11 @@ logger = init_logger(__name__)

 def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
     if compilation_config.use_inductor:
-        if envs.VLLM_USE_STANDALONE_COMPILE and is_torch_equal_or_newer(
-                "2.8.0.dev"):
+        # Use standalone compile only if requested, version is new enough,
+        # and the symbol actually exists in this PyTorch build.
+        if (envs.VLLM_USE_STANDALONE_COMPILE
+                and is_torch_equal_or_newer("2.8.0.dev")
+                and hasattr(torch._inductor, "standalone_compile")):
             logger.debug("Using InductorStandaloneAdaptor")
             return InductorStandaloneAdaptor()
         else:
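
A standalone sketch of the same three-part guard, for illustration only (the helper name below is made up; vLLM's own check lives in make_compiler above): the standalone Inductor path is taken only when it is requested, the installed torch reports at least 2.8.0.dev, and the standalone_compile symbol is actually present in that build.

import torch
import torch._inductor
from packaging import version

def can_use_standalone_compile(requested: bool) -> bool:
    # Mirrors the new condition: flag, version, and symbol must all agree.
    new_enough = version.parse(torch.__version__) >= version.parse("2.8.0.dev")
    has_symbol = hasattr(torch._inductor, "standalone_compile")
    return requested and new_enough and has_symbol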

View File

@@ -964,6 +964,9 @@ class ModelConfig:
             "modelopt",
             "modelopt_fp4",
             "petit_nvfp4",
+            # Ensure heavy backends are probed last to avoid unnecessary
+            # imports during override detection (e.g., MXFP4 imports Triton)
+            "mxfp4",
         ]
         quantization_methods = [
             q for q in supported_quantization if q not in overrides
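
For illustration, a simplified sketch of the override-detection loop this ordering feeds (condensed from the real implementation; the function below is not the exact vLLM code): each candidate's backend module is imported when it is probed, so placing heavyweight backends such as mxfp4 at the end means they are only touched if nothing cheaper matched.

from typing import Optional

from vllm.model_executor.layers.quantization import get_quantization_config

def detect_override(hf_quant_cfg: dict, user_quant: Optional[str],
                    candidates: list[str]) -> Optional[str]:
    # Probe candidates in order; importing a backend is the expensive part.
    for name in candidates:
        method = get_quantization_config(name)  # imports the backend module
        override = method.override_quantization_method(hf_quant_cfg, user_quant)
        if override is not None:
            return override
    return None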

View File

@@ -20,10 +20,10 @@ if has_triton_kernels():
         from triton_kernels.matmul_ogs import (FnSpecs, FusedActivation,
                                                matmul_ogs)
         from triton_kernels.routing import routing
-    except ModuleNotFoundError:
+    except (ModuleNotFoundError, AttributeError) as e:
         logger.error(
             "Failed to import Triton kernels. Please make sure your triton "
-            "version is compatible.")
+            "version is compatible. Error: %s", e)


 def triton_kernel_moe_forward(
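
The same defensive optional-import pattern in isolation, assuming only the standard library plus an optional triton_kernels install: an incompatible version can raise AttributeError at import time rather than ModuleNotFoundError, so both are caught and the original error is logged.

import logging

logger = logging.getLogger(__name__)

try:
    from triton_kernels.routing import routing  # optional dependency
except (ModuleNotFoundError, AttributeError) as e:
    routing = None
    logger.error(
        "Failed to import Triton kernels. Please make sure your triton "
        "version is compatible. Error: %s", e)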

View File

@@ -160,6 +160,7 @@ class ModelOptFp8Config(QuantizationConfig):
     def is_layer_excluded(self, prefix: str) -> bool:
         """
         Check if a layer should be excluded from quantization.
+        Handles both exact matching (for fused layers) and substring matching.

         This method handles both regular models and multimodal models that use
         the language_model prefix. For multimodal models, it checks if the
@@ -168,11 +169,18 @@ class ModelOptFp8Config(QuantizationConfig):
         if self.exclude_modules is None:
             return False

-        # Check if any excluded module matches the prefix
+        # First check exact matching with fused layer support
+        if is_layer_skipped(prefix, self.exclude_modules,
+                            self.packed_modules_mapping):
+            return True
+
+        # Then check substring matching for patterns not caught by exact match
         for module in self.exclude_modules:
-            if (module in prefix
-                    or (prefix.startswith("language_model.")
-                        and module in prefix.removeprefix("language_model."))):
+            # Skip exact matches already handled above
+            if (module != prefix and
+                    (module in prefix or
+                     (prefix.startswith("language_model.")
+                      and module in prefix.removeprefix("language_model.")))):
                 return True
         return False
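
A self-contained sketch of the substring branch above, with made-up module names, showing how an exclusion entry written against the bare text backbone still matches when the runtime prefix carries the extra "language_model." wrapper of a multimodal model:

def is_excluded(prefix: str, exclude_modules: list[str]) -> bool:
    # Substring check only; the exact/fused-layer check is handled separately
    # by is_layer_skipped in the real code.
    for module in exclude_modules:
        if module != prefix and (
                module in prefix
                or (prefix.startswith("language_model.")
                    and module in prefix.removeprefix("language_model."))):
            return True
    return False

# Illustrative prefixes, not taken from a real checkpoint.
assert is_excluded("language_model.model.layers.0.self_attn.o_proj",
                   ["model.layers.0.self_attn"])
assert not is_excluded("model.layers.1.mlp.gate_proj",
                       ["model.layers.0.self_attn"])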
@@ -180,9 +188,10 @@
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import
         if isinstance(layer, LinearBase):
-            if (is_layer_skipped(prefix, self.exclude_modules,
-                                 self.packed_modules_mapping)
-                    or self.is_layer_excluded(prefix)):
+            if self.is_layer_excluded(prefix):
                 return UnquantizedLinearMethod()
+            # Check if this is a vision model layer that should not be quantized
+            if ("vision_tower" in prefix or "vision_model" in prefix):
+                return UnquantizedLinearMethod()
             return ModelOptFp8LinearMethod(self)
         elif isinstance(layer, Attention):
@@ -778,22 +787,34 @@ class ModelOptNvFp4Config(QuantizationConfig):
         return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
                    exclude_modules, group_size)

-    def is_layer_excluded(self, prefix: str,
-                          exclude_modules: list[str]) -> bool:
+    def is_layer_excluded(self, prefix: str) -> bool:
+        """
+        Check if a layer should be excluded from quantization.
+        Handles both exact matching (for fused layers) and pattern matching.
+        """
+        # First check exact matching with fused layer support
+        if is_layer_skipped(prefix, self.exclude_modules,
+                            self.packed_modules_mapping):
+            return True
+
+        # Check regex pattern matching for patterns not caught by exact match
         import regex as re
-        for pattern in exclude_modules:
-            regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
-            if re.fullmatch(regex_str, prefix):
-                return True
+        for pattern in self.exclude_modules:
+            # Skip patterns that would be caught by exact matching
+            if '*' in pattern or '.' in pattern:
+                regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
+                if re.fullmatch(regex_str, prefix):
+                    return True
         return False

     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import
         if isinstance(layer, LinearBase):
-            if (is_layer_skipped(prefix, self.exclude_modules,
-                                 self.packed_modules_mapping)
-                    or self.is_layer_excluded(prefix, self.exclude_modules)):
+            if self.is_layer_excluded(prefix):
                 return UnquantizedLinearMethod()
+            # Check if this is a vision model layer that should not be quantized
+            if ("vision_tower" in prefix or "vision_model" in prefix):
+                return UnquantizedLinearMethod()
             return ModelOptNvFp4LinearMethod(self)
         elif isinstance(layer, Attention):
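
A small sketch of the glob-style pattern branch, with illustrative pattern and prefixes (regex is the third-party module the surrounding code already imports):

import regex as re

def matches_exclude_pattern(prefix: str, pattern: str) -> bool:
    # "*" in a ModelOpt exclude entry is treated as a wildcard; literal dots
    # are escaped so "model.layers.*" does not match "model_layers_x".
    regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
    return re.fullmatch(regex_str, prefix) is not None

assert matches_exclude_pattern("model.layers.10.self_attn.q_proj",
                               "model.layers.*.self_attn.*")
assert not matches_exclude_pattern("model.layers.10.mlp.up_proj",
                                   "model.layers.*.self_attn.*")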

View File

@@ -446,6 +446,22 @@ class Gemma3Model(nn.Module):
                 weight_loader(param, loaded_weight)
                 loaded_params.add(scale_name)
                 continue
+
+            # Check if this is a scale parameter that needs remapping first
+            if name.endswith(
+                    (".k_scale", ".v_scale", ".q_scale", ".prob_scale")):
+                # Try to remap the scale name first
+                remapped_name = maybe_remap_kv_scale_name(name, params_dict)
+                if remapped_name is not None and remapped_name in params_dict:
+                    # Successfully remapped, use the remapped name
+                    param = params_dict[remapped_name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
+                    loaded_params.add(remapped_name)
+                    continue
+                # If remapping failed, continue with normal processing
+
             for (param_name, shard_name, shard_id) in stacked_params_mapping:
                 if shard_name not in name:
                     continue
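
The new branch in condensed form, under the assumption that maybe_remap_kv_scale_name translates a checkpoint scale name (for example a key ending in k_proj.k_scale) into the name the model actually registers, returning None when it cannot; the helper below and its argument names are illustrative. The same branch is mirrored in the SigLIP vision tower's loader later in this commit.

from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)

def try_load_remapped_scale(name, loaded_weight, params_dict, loaded_params):
    # Returns True if the tensor was consumed as a remapped attention scale.
    if not name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale")):
        return False
    remapped = maybe_remap_kv_scale_name(name, params_dict)
    if remapped is None or remapped not in params_dict:
        return False  # fall through to the normal stacked-parameter loop
    param = params_dict[remapped]
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    weight_loader(param, loaded_weight)
    loaded_params.add(remapped)
    return True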

View File

@@ -20,7 +20,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader, maybe_remap_kv_scale_name)

 from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
@@ -506,6 +507,21 @@ class SiglipVisionModel(nn.Module):
                 if layer_idx >= layer_count:
                     continue

+            # Check if this is a scale parameter that needs remapping first
+            if name.endswith(
+                    (".k_scale", ".v_scale", ".q_scale", ".prob_scale")):
+                # Try to remap the scale name first
+                remapped_name = maybe_remap_kv_scale_name(name, params_dict)
+                if remapped_name is not None and remapped_name in params_dict:
+                    # Successfully remapped, use the remapped name
+                    param = params_dict[remapped_name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
+                    loaded_params.add(remapped_name)
+                    continue
+                # If remapping failed, continue with normal processing
+
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue