diff --git a/tests/kernels/moe/test_modular_kernel_combinations.py b/tests/kernels/moe/test_modular_kernel_combinations.py
index 19c4301bd23d5..1c7e62d7aa4c8 100644
--- a/tests/kernels/moe/test_modular_kernel_combinations.py
+++ b/tests/kernels/moe/test_modular_kernel_combinations.py
@@ -11,7 +11,8 @@
 import pytest
 import torch
 
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
-from vllm.config import VllmConfig, current_platform, set_current_vllm_config
+from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.platforms import current_platform
 from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
 from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index 3cc0fc3106f5a..d6bdb31a3c630 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -31,8 +31,11 @@ logger = init_logger(__name__)
 
 def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
     if compilation_config.use_inductor:
-        if envs.VLLM_USE_STANDALONE_COMPILE and is_torch_equal_or_newer(
-                "2.8.0.dev"):
+        # Use standalone compile only if requested, version is new enough,
+        # and the symbol actually exists in this PyTorch build.
+        if (envs.VLLM_USE_STANDALONE_COMPILE
+                and is_torch_equal_or_newer("2.8.0.dev")
+                and hasattr(torch._inductor, "standalone_compile")):
             logger.debug("Using InductorStandaloneAdaptor")
             return InductorStandaloneAdaptor()
         else:
diff --git a/vllm/config/model.py b/vllm/config/model.py
index 21457d3660a23..4e847922b61e6 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -964,6 +964,9 @@ class ModelConfig:
             "modelopt",
             "modelopt_fp4",
             "petit_nvfp4",
+            # Ensure heavy backends are probed last to avoid unnecessary
+            # imports during override detection (e.g., MXFP4 imports Triton)
+            "mxfp4",
         ]
         quantization_methods = [
             q for q in supported_quantization if q not in overrides
diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
index 08a9b34a42457..f12d3807517ff 100644
--- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
+++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
@@ -20,10 +20,10 @@ if has_triton_kernels():
         from triton_kernels.matmul_ogs import (FnSpecs, FusedActivation,
                                                matmul_ogs)
         from triton_kernels.routing import routing
-    except ModuleNotFoundError:
+    except (ModuleNotFoundError, AttributeError) as e:
         logger.error(
             "Failed to import Triton kernels. Please make sure your triton "
-            "version is compatible.")
+            "version is compatible. Error: %s", e)
 
 
 def triton_kernel_moe_forward(
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 7eac40825ac33..1083f398a3a20 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -160,6 +160,7 @@ class ModelOptFp8Config(QuantizationConfig):
     def is_layer_excluded(self, prefix: str) -> bool:
         """
         Check if a layer should be excluded from quantization.
+        Handles both exact matching (for fused layers) and substring matching.
 
         This method handles both regular models and multimodal models that
         use the language_model prefix. For multimodal models, it checks if the
@@ -168,11 +169,18 @@
         if self.exclude_modules is None:
             return False
 
-        # Check if any excluded module matches the prefix
+        # First check exact matching with fused layer support
+        if is_layer_skipped(prefix, self.exclude_modules,
+                            self.packed_modules_mapping):
+            return True
+
+        # Then check substring matching for patterns not caught by exact match
         for module in self.exclude_modules:
-            if (module in prefix
-                    or (prefix.startswith("language_model.")
-                        and module in prefix.removeprefix("language_model."))):
+            # Skip exact matches already handled above
+            if (module != prefix and
+                (module in prefix or
+                 (prefix.startswith("language_model.")
+                  and module in prefix.removeprefix("language_model.")))):
                 return True
         return False
 
@@ -180,9 +188,10 @@
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import
         if isinstance(layer, LinearBase):
-            if (is_layer_skipped(prefix, self.exclude_modules,
-                                 self.packed_modules_mapping)
-                    or self.is_layer_excluded(prefix)):
+            if self.is_layer_excluded(prefix):
+                return UnquantizedLinearMethod()
+            # Check if this is a vision model layer that should not be quantized
+            if ("vision_tower" in prefix or "vision_model" in prefix):
                 return UnquantizedLinearMethod()
             return ModelOptFp8LinearMethod(self)
         elif isinstance(layer, Attention):
@@ -778,22 +787,34 @@ class ModelOptNvFp4Config(QuantizationConfig):
         return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
                    exclude_modules, group_size)
 
-    def is_layer_excluded(self, prefix: str,
-                          exclude_modules: list[str]) -> bool:
+    def is_layer_excluded(self, prefix: str) -> bool:
+        """
+        Check if a layer should be excluded from quantization.
+        Handles both exact matching (for fused layers) and pattern matching.
+        """
+        # First check exact matching with fused layer support
+        if is_layer_skipped(prefix, self.exclude_modules,
+                            self.packed_modules_mapping):
+            return True
+
+        # Check regex pattern matching for patterns not caught by exact match
         import regex as re
-        for pattern in exclude_modules:
-            regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
-            if re.fullmatch(regex_str, prefix):
-                return True
+        for pattern in self.exclude_modules:
+            # Skip patterns that would be caught by exact matching
+            if '*' in pattern or '.' in pattern:
+                regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
+                if re.fullmatch(regex_str, prefix):
+                    return True
         return False
 
     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import
         if isinstance(layer, LinearBase):
-            if (is_layer_skipped(prefix, self.exclude_modules,
-                                 self.packed_modules_mapping)
-                    or self.is_layer_excluded(prefix, self.exclude_modules)):
+            if self.is_layer_excluded(prefix):
+                return UnquantizedLinearMethod()
+            # Check if this is a vision model layer that should not be quantized
+            if ("vision_tower" in prefix or "vision_model" in prefix):
                 return UnquantizedLinearMethod()
             return ModelOptNvFp4LinearMethod(self)
         elif isinstance(layer, Attention):
diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py
index 1263e3049a14a..7246308d59028 100644
--- a/vllm/model_executor/models/gemma3.py
+++ b/vllm/model_executor/models/gemma3.py
@@ -446,6 +446,22 @@ class Gemma3Model(nn.Module):
                 weight_loader(param, loaded_weight)
                 loaded_params.add(scale_name)
                 continue
+
+            # Check if this is a scale parameter that needs remapping first
+            if name.endswith(
+                (".k_scale", ".v_scale", ".q_scale", ".prob_scale")):
+                # Try to remap the scale name first
+                remapped_name = maybe_remap_kv_scale_name(name, params_dict)
+                if remapped_name is not None and remapped_name in params_dict:
+                    # Successfully remapped, use the remapped name
+                    param = params_dict[remapped_name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
+                    loaded_params.add(remapped_name)
+                    continue
+                # If remapping failed, continue with normal processing
+
             for (param_name, shard_name, shard_id) in stacked_params_mapping:
                 if shard_name not in name:
                     continue
diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py
index 3630f59f53e0a..eb49d6d2c3350 100644
--- a/vllm/model_executor/models/siglip.py
+++ b/vllm/model_executor/models/siglip.py
@@ -20,7 +20,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader, maybe_remap_kv_scale_name)
 
 from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
 
@@ -506,6 +507,21 @@ class SiglipVisionModel(nn.Module):
                 if layer_idx >= layer_count:
                     continue
 
+            # Check if this is a scale parameter that needs remapping first
+            if name.endswith(
+                (".k_scale", ".v_scale", ".q_scale", ".prob_scale")):
+                # Try to remap the scale name first
+                remapped_name = maybe_remap_kv_scale_name(name, params_dict)
+                if remapped_name is not None and remapped_name in params_dict:
+                    # Successfully remapped, use the remapped name
+                    param = params_dict[remapped_name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
+                    loaded_params.add(remapped_name)
+                    continue
+                # If remapping failed, continue with normal processing
+
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
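
Note on the new exclusion logic in modelopt.py: both ModelOptFp8Config and ModelOptNvFp4Config now run an exact, fused-layer-aware check (is_layer_skipped with packed_modules_mapping) before falling back to substring or wildcard matching. The standalone sketch below illustrates that two-stage flow under simplified assumptions; FUSED_MAPPING, exact_match_skipped, and the example prefixes are hypothetical stand-ins, not vLLM's actual API.

import re

# Hypothetical stand-in for vLLM's packed_modules_mapping
# (fused layer name -> names of the original sub-layers).
FUSED_MAPPING = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}


def exact_match_skipped(prefix: str, exclude: list[str],
                        fused_mapping: dict[str, list[str]]) -> bool:
    """Illustration of 'exact matching with fused layer support': a fused
    layer is excluded only when all of its constituent sub-layers are."""
    last = prefix.split(".")[-1]
    if last in fused_mapping:
        shards = [prefix.replace(last, sub) for sub in fused_mapping[last]]
        return all(any(s == m or s.endswith("." + m) for m in exclude)
                   for s in shards)
    return any(prefix == m or prefix.endswith("." + m) for m in exclude)


def is_layer_excluded(prefix: str, exclude: list[str]) -> bool:
    """Two-stage check mirroring the patched logic: exact match first,
    then substring / wildcard patterns as a fallback."""
    if exact_match_skipped(prefix, exclude, FUSED_MAPPING):
        return True
    for pattern in exclude:
        if "*" in pattern:
            # Escape dots, expand '*' to '.*', and require a full match.
            if re.fullmatch(pattern.replace(".", r"\.").replace("*", ".*"),
                            prefix):
                return True
        elif pattern in prefix:
            return True
    return False


if __name__ == "__main__":
    exclude = ["lm_head", "model.layers.*.mlp.gate"]
    print(is_layer_excluded("lm_head", exclude))                          # True (exact)
    print(is_layer_excluded("model.layers.3.mlp.gate", exclude))          # True (wildcard)
    print(is_layer_excluded("model.layers.3.self_attn.qkv_proj", exclude))  # False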
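
Note on the loader changes in gemma3.py and siglip.py: both load_weights implementations now try to remap *_scale checkpoint names via maybe_remap_kv_scale_name before entering the stacked-parameters loop, and only fall through to normal processing when remapping fails. Below is a simplified, self-contained sketch of that control flow; maybe_remap_scale_name, the toy remapping rule, and the example parameter names are illustrative assumptions rather than vLLM's real loader.

from typing import Optional


def maybe_remap_scale_name(name: str, params_dict: dict) -> Optional[str]:
    # Toy stand-in for vLLM's maybe_remap_kv_scale_name: map a checkpoint
    # name such as "...self_attn.k_scale" onto the model's registered
    # "...self_attn.attn.k_scale" parameter, or None when nothing matches.
    if name.endswith((".k_scale", ".v_scale")):
        module, _, scale = name.rpartition(".")
        candidate = f"{module}.attn.{scale}"
        return candidate if candidate in params_dict else None
    return name if name in params_dict else None


def load_one_weight(name: str, loaded_weight, params_dict: dict,
                    loaded_params: set) -> bool:
    """Mirror of the patched branch: remap scale names first, fall back to
    normal processing (the stacked-params loop, not shown) on failure."""
    if name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale")):
        remapped = maybe_remap_scale_name(name, params_dict)
        if remapped is not None and remapped in params_dict:
            params_dict[remapped] = loaded_weight  # stand-in for weight_loader
            loaded_params.add(remapped)
            return True
    return False  # caller continues with the usual stacked-params handling


if __name__ == "__main__":
    params = {"encoder.layers.0.self_attn.attn.k_scale": None}
    loaded: set = set()
    handled = load_one_weight("encoder.layers.0.self_attn.k_scale", 1.0,
                              params, loaded)
    print(handled, loaded)  # True {'encoder.layers.0.self_attn.attn.k_scale'}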