Enable modelopt gemma3 nvfp4/fp8, make workflow more robust (#22771)
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent e54a476058
commit 6e94161f94
@@ -11,7 +11,8 @@ import pytest
 import torch
 
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
-from vllm.config import VllmConfig, current_platform, set_current_vllm_config
+from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.platforms import current_platform
 from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
 from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
@@ -31,8 +31,11 @@ logger = init_logger(__name__)
 
 
 def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
     if compilation_config.use_inductor:
-        if envs.VLLM_USE_STANDALONE_COMPILE and is_torch_equal_or_newer(
-                "2.8.0.dev"):
+        # Use standalone compile only if requested, version is new enough,
+        # and the symbol actually exists in this PyTorch build.
+        if (envs.VLLM_USE_STANDALONE_COMPILE
+                and is_torch_equal_or_newer("2.8.0.dev")
+                and hasattr(torch._inductor, "standalone_compile")):
             logger.debug("Using InductorStandaloneAdaptor")
             return InductorStandaloneAdaptor()
         else:
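
Note: the hunk above combines an opt-in flag, a minimum-version check, and a hasattr() probe so a symbol missing from a given PyTorch build can never raise at dispatch time. Below is a minimal, self-contained sketch of that pattern, assuming only the packaging package is available; the environment variable and helper names are illustrative stand-ins, not vLLM's.

# Sketch only: opt-in flag + version check + hasattr() probe, mirroring the
# guard added above. USE_STANDALONE_COMPILE here is a hypothetical name.
import importlib
import os

import torch
from packaging import version


def _standalone_compile_available() -> bool:
    opted_in = os.environ.get("USE_STANDALONE_COMPILE", "1") == "1"
    new_enough = version.parse(torch.__version__).release >= (2, 8)
    try:
        inductor = importlib.import_module("torch._inductor")
    except ImportError:
        return False
    return opted_in and new_enough and hasattr(inductor, "standalone_compile")


if __name__ == "__main__":
    print("standalone compile usable:", _standalone_compile_available())
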
@@ -964,6 +964,9 @@ class ModelConfig:
             "modelopt",
             "modelopt_fp4",
             "petit_nvfp4",
+            # Ensure heavy backends are probed last to avoid unnecessary
+            # imports during override detection (e.g., MXFP4 imports Triton)
+            "mxfp4",
         ]
         quantization_methods = [
             q for q in supported_quantization if q not in overrides
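
The ordering comment matters because override detection walks the list in order and imports each backend as it is probed, so putting import-heavy backends such as mxfp4 last keeps them untouched when an earlier method already matches. A hedged sketch of that first-match probing (the helper names are hypothetical, not vLLM's actual hooks):

# Hedged sketch of first-match probing over an ordered list of method names.
# Probing stops at the first hit, so expensive-to-import backends go last.
from typing import Callable, Optional


def detect_override(overrides: list[str],
                    probe: Callable[[str], Optional[str]]) -> Optional[str]:
    for name in overrides:
        result = probe(name)  # may import the backend lazily
        if result is not None:
            return result  # later (heavier) entries are never imported
    return None


if __name__ == "__main__":
    order = ["modelopt", "modelopt_fp4", "petit_nvfp4", "mxfp4"]
    # Pretend the checkpoint matches modelopt_fp4: mxfp4 is never probed.
    print(detect_override(order, lambda n: n if n == "modelopt_fp4" else None))
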
@@ -20,10 +20,10 @@ if has_triton_kernels():
         from triton_kernels.matmul_ogs import (FnSpecs, FusedActivation,
                                                matmul_ogs)
         from triton_kernels.routing import routing
-    except ModuleNotFoundError:
+    except (ModuleNotFoundError, AttributeError) as e:
         logger.error(
             "Failed to import Triton kernels. Please make sure your triton "
-            "version is compatible.")
+            "version is compatible. Error: %s", e)
 
 
 def triton_kernel_moe_forward(
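
For context, this is the hardened optional-import pattern the hunk moves to: an incompatible triton can fail with AttributeError while resolving symbols, not only ModuleNotFoundError, and the error text is deferred to lazy %s formatting. A small standalone sketch, with an illustrative HAS_TRITON_KERNELS flag that is not part of the module:

# Robust optional import: catch both failure modes and report them lazily.
import logging

logger = logging.getLogger(__name__)

HAS_TRITON_KERNELS = True
try:
    from triton_kernels.matmul_ogs import matmul_ogs  # noqa: F401
    from triton_kernels.routing import routing  # noqa: F401
except (ModuleNotFoundError, AttributeError) as e:
    HAS_TRITON_KERNELS = False
    logger.error(
        "Failed to import Triton kernels. Please make sure your triton "
        "version is compatible. Error: %s", e)
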
@@ -160,6 +160,7 @@ class ModelOptFp8Config(QuantizationConfig):
     def is_layer_excluded(self, prefix: str) -> bool:
         """
         Check if a layer should be excluded from quantization.
+        Handles both exact matching (for fused layers) and substring matching.
 
         This method handles both regular models and multimodal models that use
         the language_model prefix. For multimodal models, it checks if the
@@ -168,11 +169,18 @@ class ModelOptFp8Config(QuantizationConfig):
         if self.exclude_modules is None:
             return False
 
-        # Check if any excluded module matches the prefix
+        # First check exact matching with fused layer support
+        if is_layer_skipped(prefix, self.exclude_modules,
+                            self.packed_modules_mapping):
+            return True
+
+        # Then check substring matching for patterns not caught by exact match
         for module in self.exclude_modules:
-            if (module in prefix
-                    or (prefix.startswith("language_model.")
-                        and module in prefix.removeprefix("language_model."))):
+            # Skip exact matches already handled above
+            if (module != prefix and
+                (module in prefix or
+                 (prefix.startswith("language_model.")
+                  and module in prefix.removeprefix("language_model.")))):
                 return True
         return False
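
To make the two-stage check above concrete, here is a hedged, standalone sketch with a toy stand-in for is_layer_skipped and made-up layer names, showing how an excluded module matches both a plain prefix and the same prefix under a multimodal language_model. wrapper:

# Standalone illustration of the FP8 exclusion check: exact match first, then
# substring match (also after stripping the language_model. wrapper).
def is_exact_match(prefix: str, exclude: list[str]) -> bool:
    return prefix in exclude  # toy stand-in for is_layer_skipped + fused mapping


def is_layer_excluded(prefix: str, exclude: list[str]) -> bool:
    if is_exact_match(prefix, exclude):
        return True
    for module in exclude:
        if module != prefix and (
                module in prefix
                or (prefix.startswith("language_model.")
                    and module in prefix.removeprefix("language_model."))):
            return True
    return False


if __name__ == "__main__":
    exclude = ["lm_head", "model.layers.0.mlp"]
    print(is_layer_excluded("lm_head", exclude))                      # True (exact)
    print(is_layer_excluded("language_model.lm_head", exclude))       # True (substring)
    print(is_layer_excluded("model.layers.1.self_attn.q_proj", exclude))  # False
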
@@ -180,9 +188,10 @@ class ModelOptFp8Config(QuantizationConfig):
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import
         if isinstance(layer, LinearBase):
-            if (is_layer_skipped(prefix, self.exclude_modules,
-                                 self.packed_modules_mapping)
-                    or self.is_layer_excluded(prefix)):
+            if self.is_layer_excluded(prefix):
                 return UnquantizedLinearMethod()
+            # Check if this is a vision model layer that should not be quantized
+            if ("vision_tower" in prefix or "vision_model" in prefix):
+                return UnquantizedLinearMethod()
             return ModelOptFp8LinearMethod(self)
         elif isinstance(layer, Attention):
@@ -778,22 +787,34 @@ class ModelOptNvFp4Config(QuantizationConfig):
         return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
                    exclude_modules, group_size)
 
-    def is_layer_excluded(self, prefix: str,
-                          exclude_modules: list[str]) -> bool:
+    def is_layer_excluded(self, prefix: str) -> bool:
+        """
+        Check if a layer should be excluded from quantization.
+        Handles both exact matching (for fused layers) and pattern matching.
+        """
+        # First check exact matching with fused layer support
+        if is_layer_skipped(prefix, self.exclude_modules,
+                            self.packed_modules_mapping):
+            return True
+
+        # Check regex pattern matching for patterns not caught by exact match
         import regex as re
-        for pattern in exclude_modules:
-            regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
-            if re.fullmatch(regex_str, prefix):
-                return True
+        for pattern in self.exclude_modules:
+            # Skip patterns that would be caught by exact matching
+            if '*' in pattern or '.' in pattern:
+                regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
+                if re.fullmatch(regex_str, prefix):
+                    return True
         return False
 
     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import
         if isinstance(layer, LinearBase):
-            if (is_layer_skipped(prefix, self.exclude_modules,
-                                 self.packed_modules_mapping)
-                    or self.is_layer_excluded(prefix, self.exclude_modules)):
+            if self.is_layer_excluded(prefix):
                 return UnquantizedLinearMethod()
+            # Check if this is a vision model layer that should not be quantized
+            if ("vision_tower" in prefix or "vision_model" in prefix):
+                return UnquantizedLinearMethod()
             return ModelOptNvFp4LinearMethod(self)
         elif isinstance(layer, Attention):
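
The NvFP4 path above converts glob-style exclude patterns to regexes only when the exact match did not already handle them. A small sketch of that conversion, using the standard re module here rather than the third-party regex package the file imports:

# Sketch of wildcard-to-regex matching for NvFP4 exclude patterns.
# '.' becomes a literal dot, '*' becomes "any run of characters".
import re


def matches_pattern(prefix: str, pattern: str) -> bool:
    regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
    return re.fullmatch(regex_str, prefix) is not None


if __name__ == "__main__":
    print(matches_pattern("model.layers.3.mlp.down_proj",
                          "model.layers.*.mlp.down_proj"))   # True
    print(matches_pattern("model.layers.3.self_attn.q_proj",
                          "model.layers.*.mlp.down_proj"))   # False
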
@@ -446,6 +446,22 @@ class Gemma3Model(nn.Module):
                     weight_loader(param, loaded_weight)
                     loaded_params.add(scale_name)
                     continue
+
+            # Check if this is a scale parameter that needs remapping first
+            if name.endswith(
+                    (".k_scale", ".v_scale", ".q_scale", ".prob_scale")):
+                # Try to remap the scale name first
+                remapped_name = maybe_remap_kv_scale_name(name, params_dict)
+                if remapped_name is not None and remapped_name in params_dict:
+                    # Successfully remapped, use the remapped name
+                    param = params_dict[remapped_name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
+                    loaded_params.add(remapped_name)
+                    continue
+                # If remapping failed, continue with normal processing
+
             for (param_name, shard_name, shard_id) in stacked_params_mapping:
                 if shard_name not in name:
                     continue
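
The remapping step added above exists because checkpoints often store attention scales under names like ...self_attn.k_scale while the runtime parameter lives under the attention submodule. A hedged toy version of that renaming (not vLLM's maybe_remap_kv_scale_name, whose exact rules differ) looks like this:

# Toy illustration of kv-scale name remapping: rewrite the checkpoint key to
# the assumed runtime name only if that name actually exists in params_dict.
from typing import Optional


def remap_scale_name(name: str, params_dict: dict) -> Optional[str]:
    if not name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale")):
        return None
    module, _, scale = name.rpartition(".")
    candidate = f"{module}.attn.{scale}"  # assumed runtime layout
    return candidate if candidate in params_dict else None


if __name__ == "__main__":
    params = {"model.layers.0.self_attn.attn.k_scale": object()}
    print(remap_scale_name("model.layers.0.self_attn.k_scale", params))
    # -> model.layers.0.self_attn.attn.k_scale
    print(remap_scale_name("model.layers.0.self_attn.k_scale", {}))  # -> None
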
@@ -20,7 +20,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader, maybe_remap_kv_scale_name)
 
 from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
@@ -506,6 +507,21 @@ class SiglipVisionModel(nn.Module):
             if layer_idx >= layer_count:
                 continue
 
+            # Check if this is a scale parameter that needs remapping first
+            if name.endswith(
+                    (".k_scale", ".v_scale", ".q_scale", ".prob_scale")):
+                # Try to remap the scale name first
+                remapped_name = maybe_remap_kv_scale_name(name, params_dict)
+                if remapped_name is not None and remapped_name in params_dict:
+                    # Successfully remapped, use the remapped name
+                    param = params_dict[remapped_name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
+                    loaded_params.add(remapped_name)
+                    continue
+                # If remapping failed, continue with normal processing
+
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue