Enable modelopt gemma3 nvfp4/fp8, make workflow more robust (#22771)

Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Authored by Zhiyu on 2025-09-19 15:40:33 -07:00; committed by GitHub
parent 711e912946
commit 431535b522
7 changed files with 82 additions and 22 deletions

View File

@@ -11,7 +11,8 @@ import pytest
 import torch
 
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
-from vllm.config import VllmConfig, current_platform, set_current_vllm_config
+from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.platforms import current_platform
 from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
 from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe

View File

@@ -31,8 +31,11 @@ logger = init_logger(__name__)
 def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
     if compilation_config.use_inductor:
-        if envs.VLLM_USE_STANDALONE_COMPILE and is_torch_equal_or_newer(
-                "2.8.0.dev"):
+        # Use standalone compile only if requested, version is new enough,
+        # and the symbol actually exists in this PyTorch build.
+        if (envs.VLLM_USE_STANDALONE_COMPILE
+                and is_torch_equal_or_newer("2.8.0.dev")
+                and hasattr(torch._inductor, "standalone_compile")):
             logger.debug("Using InductorStandaloneAdaptor")
             return InductorStandaloneAdaptor()
         else:
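The added hasattr guard means standalone compile is only selected when the running PyTorch build actually exposes the entry point, not merely when the version string is new enough (nightly or custom builds can report 2.8.0.dev without shipping the symbol). A minimal standalone sketch of the same probing pattern; the helper below is illustrative and not part of vLLM:

import importlib

import torch
from packaging import version


def supports_standalone_compile(requested: bool,
                                min_version: str = "2.8.0.dev") -> bool:
    # Illustrative three-part probe: user opt-in, version floor, symbol check.
    if not requested:
        return False
    if version.parse(torch.__version__) < version.parse(min_version):
        return False
    # A new-enough version string does not guarantee the symbol exists,
    # so import the module and probe for the attribute explicitly.
    try:
        inductor = importlib.import_module("torch._inductor")
    except ImportError:
        return False
    return hasattr(inductor, "standalone_compile")


print(supports_standalone_compile(requested=True))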

View File

@@ -964,6 +964,9 @@ class ModelConfig:
             "modelopt",
             "modelopt_fp4",
             "petit_nvfp4",
+            # Ensure heavy backends are probed last to avoid unnecessary
+            # imports during override detection (e.g., MXFP4 imports Triton)
+            "mxfp4",
         ]
         quantization_methods = [
             q for q in supported_quantization if q not in overrides
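Keeping "mxfp4" at the end matters because the Triton-importing backend should only be probed when the lighter methods have already declined. A toy sketch of why probe order can matter when detection short-circuits on the first match; the probe functions and config dict are hypothetical, not vLLM's API:

from collections.abc import Callable


def probe_modelopt(hf_quant_cfg: dict) -> bool:
    # Cheap check, no heavy imports.
    return hf_quant_cfg.get("quant_algo") == "FP8"


def probe_mxfp4(hf_quant_cfg: dict) -> bool:
    # In the real backend, merely importing this module pulls in Triton.
    return hf_quant_cfg.get("quant_algo") == "MXFP4"


PROBES: list[tuple[str, Callable[[dict], bool]]] = [
    ("modelopt", probe_modelopt),
    ("mxfp4", probe_mxfp4),  # probed last, so it is rarely touched at all
]


def detect_override(hf_quant_cfg: dict) -> str | None:
    for name, probe in PROBES:
        if probe(hf_quant_cfg):
            return name  # an earlier match short-circuits heavier probes
    return None


print(detect_override({"quant_algo": "FP8"}))  # -> "modelopt"; mxfp4 never runs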

View File

@ -20,10 +20,10 @@ if has_triton_kernels():
from triton_kernels.matmul_ogs import (FnSpecs, FusedActivation, from triton_kernels.matmul_ogs import (FnSpecs, FusedActivation,
matmul_ogs) matmul_ogs)
from triton_kernels.routing import routing from triton_kernels.routing import routing
except ModuleNotFoundError: except (ModuleNotFoundError, AttributeError) as e:
logger.error( logger.error(
"Failed to import Triton kernels. Please make sure your triton " "Failed to import Triton kernels. Please make sure your triton "
"version is compatible.") "version is compatible. Error: %s", e)
def triton_kernel_moe_forward( def triton_kernel_moe_forward(
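Broadening the except clause covers the case where triton is installed but incompatible: the triton_kernels import chain can then fail with an AttributeError on a missing symbol rather than a ModuleNotFoundError, which the old guard did not catch. A self-contained sketch of the same guarded-import pattern; the module name is a placeholder:

import logging

logger = logging.getLogger(__name__)

try:
    # Placeholder for an optional, version-sensitive dependency.
    from some_optional_backend.kernels import fused_matmul  # noqa: F401
    HAS_BACKEND = True
except (ModuleNotFoundError, AttributeError) as e:
    # AttributeError can surface when the package is present but an
    # incompatible version lacks the symbols the import chain touches.
    HAS_BACKEND = False
    logger.error(
        "Failed to import optional kernels. Please make sure the installed "
        "version is compatible. Error: %s", e)

print(HAS_BACKEND)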

View File

@@ -160,6 +160,7 @@ class ModelOptFp8Config(QuantizationConfig):
     def is_layer_excluded(self, prefix: str) -> bool:
         """
         Check if a layer should be excluded from quantization.
+        Handles both exact matching (for fused layers) and substring matching.
 
         This method handles both regular models and multimodal models that use
         the language_model prefix. For multimodal models, it checks if the
@@ -168,11 +169,18 @@ class ModelOptFp8Config(QuantizationConfig):
         if self.exclude_modules is None:
             return False
 
-        # Check if any excluded module matches the prefix
+        # First check exact matching with fused layer support
+        if is_layer_skipped(prefix, self.exclude_modules,
+                            self.packed_modules_mapping):
+            return True
+
+        # Then check substring matching for patterns not caught by exact match
         for module in self.exclude_modules:
-            if (module in prefix
-                    or (prefix.startswith("language_model.")
-                        and module in prefix.removeprefix("language_model."))):
+            # Skip exact matches already handled above
+            if (module != prefix and
+                    (module in prefix or
+                     (prefix.startswith("language_model.")
+                      and module in prefix.removeprefix("language_model.")))):
                 return True
         return False
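With this change the FP8 exclusion check runs in two passes: an exact, fused-layer-aware pass via is_layer_skipped, followed by a substring pass that also strips the language_model. prefix that multimodal checkpoints put in front of the text stack. A standalone sketch of just the substring pass, with made-up layer names:

def substring_excluded(prefix: str, exclude_modules: list[str]) -> bool:
    # Mirrors the second pass; exact matches are assumed to have been
    # handled already by the fused-layer check, so they are skipped here.
    for module in exclude_modules:
        if module == prefix:
            continue  # exact match belongs to the first pass
        if module in prefix:
            return True
        # Multimodal checkpoints prefix the text stack with "language_model."
        if (prefix.startswith("language_model.")
                and module in prefix.removeprefix("language_model.")):
            return True
    return False


exclude = ["lm_head", "model.layers.0.self_attn"]
print(substring_excluded(
    "language_model.model.layers.0.self_attn.qkv_proj", exclude))  # True
print(substring_excluded("model.layers.1.mlp.gate_up_proj", exclude))  # False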
@@ -180,9 +188,10 @@ class ModelOptFp8Config(QuantizationConfig):
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import
         if isinstance(layer, LinearBase):
-            if (is_layer_skipped(prefix, self.exclude_modules,
-                                 self.packed_modules_mapping)
-                    or self.is_layer_excluded(prefix)):
+            if self.is_layer_excluded(prefix):
+                return UnquantizedLinearMethod()
+            # Check if this is a vision model layer that should not be quantized
+            if ("vision_tower" in prefix or "vision_model" in prefix):
                 return UnquantizedLinearMethod()
             return ModelOptFp8LinearMethod(self)
         elif isinstance(layer, Attention):
@@ -778,22 +787,34 @@ class ModelOptNvFp4Config(QuantizationConfig):
         return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
                    exclude_modules, group_size)
 
-    def is_layer_excluded(self, prefix: str,
-                          exclude_modules: list[str]) -> bool:
+    def is_layer_excluded(self, prefix: str) -> bool:
+        """
+        Check if a layer should be excluded from quantization.
+        Handles both exact matching (for fused layers) and pattern matching.
+        """
+        # First check exact matching with fused layer support
+        if is_layer_skipped(prefix, self.exclude_modules,
+                            self.packed_modules_mapping):
+            return True
+
+        # Check regex pattern matching for patterns not caught by exact match
         import regex as re
-        for pattern in exclude_modules:
-            regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
-            if re.fullmatch(regex_str, prefix):
-                return True
+        for pattern in self.exclude_modules:
+            # Skip patterns that would be caught by exact matching
+            if '*' in pattern or '.' in pattern:
+                regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
+                if re.fullmatch(regex_str, prefix):
+                    return True
         return False
 
     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import
         if isinstance(layer, LinearBase):
-            if (is_layer_skipped(prefix, self.exclude_modules,
-                                 self.packed_modules_mapping)
-                    or self.is_layer_excluded(prefix, self.exclude_modules)):
+            if self.is_layer_excluded(prefix):
+                return UnquantizedLinearMethod()
+            # Check if this is a vision model layer that should not be quantized
+            if ("vision_tower" in prefix or "vision_model" in prefix):
                 return UnquantizedLinearMethod()
             return ModelOptNvFp4LinearMethod(self)
         elif isinstance(layer, Attention):
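On the NVFP4 side, exclude entries are only expanded into regexes when they contain '*' or '.', so plain module names are left to the exact, fused-layer-aware pass above. The diff uses the third-party regex module; a small sketch of the same expansion with the standard library re, which behaves identically for these patterns:

import re


def pattern_excluded(prefix: str, exclude_modules: list[str]) -> bool:
    # Wildcard-style entries such as "model.layers.*.self_attn.*" become
    # anchored regexes; everything else is assumed handled by exact matching.
    for pattern in exclude_modules:
        if '*' in pattern or '.' in pattern:
            regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
            if re.fullmatch(regex_str, prefix):
                return True
    return False


patterns = ["model.layers.*.self_attn.*", "lm_head"]
print(pattern_excluded("model.layers.10.self_attn.qkv_proj", patterns))  # True
print(pattern_excluded("model.layers.10.mlp.down_proj", patterns))       # False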

View File

@@ -446,6 +446,22 @@ class Gemma3Model(nn.Module):
                 weight_loader(param, loaded_weight)
                 loaded_params.add(scale_name)
                 continue
+
+            # Check if this is a scale parameter that needs remapping first
+            if name.endswith(
+                (".k_scale", ".v_scale", ".q_scale", ".prob_scale")):
+                # Try to remap the scale name first
+                remapped_name = maybe_remap_kv_scale_name(name, params_dict)
+                if remapped_name is not None and remapped_name in params_dict:
+                    # Successfully remapped, use the remapped name
+                    param = params_dict[remapped_name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
+                    loaded_params.add(remapped_name)
+                    continue
+                # If remapping failed, continue with normal processing
+
             for (param_name, shard_name, shard_id) in stacked_params_mapping:
                 if shard_name not in name:
                     continue
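The new branch routes checkpoint entries ending in .k_scale/.v_scale/.q_scale/.prob_scale through maybe_remap_kv_scale_name before the stacked-parameter loop, so ModelOpt checkpoints whose scale names do not line up one-to-one with vLLM's registered parameters can still load. A simplified, standalone sketch of that kind of remapping; the specific rewrite rule of inserting ".attn" is an assumption for illustration, not the exact behavior of vLLM's helper:

import torch

SCALE_SUFFIXES = (".k_scale", ".v_scale", ".q_scale", ".prob_scale")


def remap_scale_name(name: str,
                     params_dict: dict[str, torch.Tensor]) -> str | None:
    # Return a parameter name that exists in params_dict, or None.
    # Assumed rewrite rule: the checkpoint stores scales on the decoder layer
    # (...self_attn.k_scale) while the model registers them on the inner
    # attention module (...self_attn.attn.k_scale).
    if not name.endswith(SCALE_SUFFIXES):
        return None
    if name in params_dict:
        return name
    head, _, suffix = name.rpartition(".")
    candidate = f"{head}.attn.{suffix}"
    return candidate if candidate in params_dict else None


params = {"model.layers.0.self_attn.attn.k_scale": torch.zeros(1)}
print(remap_scale_name("model.layers.0.self_attn.k_scale", params))
# -> "model.layers.0.self_attn.attn.k_scale"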

View File

@@ -20,7 +20,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader, maybe_remap_kv_scale_name)
 
 from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
@@ -506,6 +507,21 @@ class SiglipVisionModel(nn.Module):
                 if layer_idx >= layer_count:
                     continue
 
+            # Check if this is a scale parameter that needs remapping first
+            if name.endswith(
+                (".k_scale", ".v_scale", ".q_scale", ".prob_scale")):
+                # Try to remap the scale name first
+                remapped_name = maybe_remap_kv_scale_name(name, params_dict)
+                if remapped_name is not None and remapped_name in params_dict:
+                    # Successfully remapped, use the remapped name
+                    param = params_dict[remapped_name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
+                    loaded_params.add(remapped_name)
+                    continue
+                # If remapping failed, continue with normal processing
+
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue