mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-11 13:35:52 +08:00
Enable modelopt gemma3 nvfp4/fp8, make workflow more robust (#22771)
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
parent
711e912946
commit
431535b522
@ -11,7 +11,8 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
from vllm.config import VllmConfig, current_platform, set_current_vllm_config
|
from vllm.config import VllmConfig, set_current_vllm_config
|
||||||
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
|
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
|
||||||
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||||
|
|
||||||
|
|||||||
@ -31,8 +31,11 @@ logger = init_logger(__name__)
|
|||||||
|
|
||||||
def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
|
def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
|
||||||
if compilation_config.use_inductor:
|
if compilation_config.use_inductor:
|
||||||
if envs.VLLM_USE_STANDALONE_COMPILE and is_torch_equal_or_newer(
|
# Use standalone compile only if requested, version is new enough,
|
||||||
"2.8.0.dev"):
|
# and the symbol actually exists in this PyTorch build.
|
||||||
|
if (envs.VLLM_USE_STANDALONE_COMPILE
|
||||||
|
and is_torch_equal_or_newer("2.8.0.dev")
|
||||||
|
and hasattr(torch._inductor, "standalone_compile")):
|
||||||
logger.debug("Using InductorStandaloneAdaptor")
|
logger.debug("Using InductorStandaloneAdaptor")
|
||||||
return InductorStandaloneAdaptor()
|
return InductorStandaloneAdaptor()
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -964,6 +964,9 @@ class ModelConfig:
|
|||||||
"modelopt",
|
"modelopt",
|
||||||
"modelopt_fp4",
|
"modelopt_fp4",
|
||||||
"petit_nvfp4",
|
"petit_nvfp4",
|
||||||
|
# Ensure heavy backends are probed last to avoid unnecessary
|
||||||
|
# imports during override detection (e.g., MXFP4 imports Triton)
|
||||||
|
"mxfp4",
|
||||||
]
|
]
|
||||||
quantization_methods = [
|
quantization_methods = [
|
||||||
q for q in supported_quantization if q not in overrides
|
q for q in supported_quantization if q not in overrides
|
||||||
|
|||||||
@ -20,10 +20,10 @@ if has_triton_kernels():
|
|||||||
from triton_kernels.matmul_ogs import (FnSpecs, FusedActivation,
|
from triton_kernels.matmul_ogs import (FnSpecs, FusedActivation,
|
||||||
matmul_ogs)
|
matmul_ogs)
|
||||||
from triton_kernels.routing import routing
|
from triton_kernels.routing import routing
|
||||||
except ModuleNotFoundError:
|
except (ModuleNotFoundError, AttributeError) as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Failed to import Triton kernels. Please make sure your triton "
|
"Failed to import Triton kernels. Please make sure your triton "
|
||||||
"version is compatible.")
|
"version is compatible. Error: %s", e)
|
||||||
|
|
||||||
|
|
||||||
def triton_kernel_moe_forward(
|
def triton_kernel_moe_forward(
|
||||||
|
|||||||
@ -160,6 +160,7 @@ class ModelOptFp8Config(QuantizationConfig):
|
|||||||
def is_layer_excluded(self, prefix: str) -> bool:
|
def is_layer_excluded(self, prefix: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if a layer should be excluded from quantization.
|
Check if a layer should be excluded from quantization.
|
||||||
|
Handles both exact matching (for fused layers) and substring matching.
|
||||||
|
|
||||||
This method handles both regular models and multimodal models that use
|
This method handles both regular models and multimodal models that use
|
||||||
the language_model prefix. For multimodal models, it checks if the
|
the language_model prefix. For multimodal models, it checks if the
|
||||||
@ -168,11 +169,18 @@ class ModelOptFp8Config(QuantizationConfig):
|
|||||||
if self.exclude_modules is None:
|
if self.exclude_modules is None:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Check if any excluded module matches the prefix
|
# First check exact matching with fused layer support
|
||||||
|
if is_layer_skipped(prefix, self.exclude_modules,
|
||||||
|
self.packed_modules_mapping):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Then check substring matching for patterns not caught by exact match
|
||||||
for module in self.exclude_modules:
|
for module in self.exclude_modules:
|
||||||
if (module in prefix
|
# Skip exact matches already handled above
|
||||||
or (prefix.startswith("language_model.")
|
if (module != prefix and
|
||||||
and module in prefix.removeprefix("language_model."))):
|
(module in prefix or
|
||||||
|
(prefix.startswith("language_model.")
|
||||||
|
and module in prefix.removeprefix("language_model.")))):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -180,9 +188,10 @@ class ModelOptFp8Config(QuantizationConfig):
|
|||||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||||
from vllm.attention.layer import Attention # Avoid circular import
|
from vllm.attention.layer import Attention # Avoid circular import
|
||||||
if isinstance(layer, LinearBase):
|
if isinstance(layer, LinearBase):
|
||||||
if (is_layer_skipped(prefix, self.exclude_modules,
|
if self.is_layer_excluded(prefix):
|
||||||
self.packed_modules_mapping)
|
return UnquantizedLinearMethod()
|
||||||
or self.is_layer_excluded(prefix)):
|
# Check if this is a vision model layer that should not be quantized
|
||||||
|
if ("vision_tower" in prefix or "vision_model" in prefix):
|
||||||
return UnquantizedLinearMethod()
|
return UnquantizedLinearMethod()
|
||||||
return ModelOptFp8LinearMethod(self)
|
return ModelOptFp8LinearMethod(self)
|
||||||
elif isinstance(layer, Attention):
|
elif isinstance(layer, Attention):
|
||||||
@ -778,22 +787,34 @@ class ModelOptNvFp4Config(QuantizationConfig):
|
|||||||
return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
|
return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
|
||||||
exclude_modules, group_size)
|
exclude_modules, group_size)
|
||||||
|
|
||||||
def is_layer_excluded(self, prefix: str,
|
def is_layer_excluded(self, prefix: str) -> bool:
|
||||||
exclude_modules: list[str]) -> bool:
|
"""
|
||||||
|
Check if a layer should be excluded from quantization.
|
||||||
|
Handles both exact matching (for fused layers) and pattern matching.
|
||||||
|
"""
|
||||||
|
# First check exact matching with fused layer support
|
||||||
|
if is_layer_skipped(prefix, self.exclude_modules,
|
||||||
|
self.packed_modules_mapping):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check regex pattern matching for patterns not caught by exact match
|
||||||
import regex as re
|
import regex as re
|
||||||
for pattern in exclude_modules:
|
for pattern in self.exclude_modules:
|
||||||
regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
|
# Skip patterns that would be caught by exact matching
|
||||||
if re.fullmatch(regex_str, prefix):
|
if '*' in pattern or '.' in pattern:
|
||||||
return True
|
regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
|
||||||
|
if re.fullmatch(regex_str, prefix):
|
||||||
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_quant_method(self, layer: torch.nn.Module,
|
def get_quant_method(self, layer: torch.nn.Module,
|
||||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||||
from vllm.attention.layer import Attention # Avoid circular import
|
from vllm.attention.layer import Attention # Avoid circular import
|
||||||
if isinstance(layer, LinearBase):
|
if isinstance(layer, LinearBase):
|
||||||
if (is_layer_skipped(prefix, self.exclude_modules,
|
if self.is_layer_excluded(prefix):
|
||||||
self.packed_modules_mapping)
|
return UnquantizedLinearMethod()
|
||||||
or self.is_layer_excluded(prefix, self.exclude_modules)):
|
# Check if this is a vision model layer that should not be quantized
|
||||||
|
if ("vision_tower" in prefix or "vision_model" in prefix):
|
||||||
return UnquantizedLinearMethod()
|
return UnquantizedLinearMethod()
|
||||||
return ModelOptNvFp4LinearMethod(self)
|
return ModelOptNvFp4LinearMethod(self)
|
||||||
elif isinstance(layer, Attention):
|
elif isinstance(layer, Attention):
|
||||||
|
|||||||
@ -446,6 +446,22 @@ class Gemma3Model(nn.Module):
|
|||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
loaded_params.add(scale_name)
|
loaded_params.add(scale_name)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Check if this is a scale parameter that needs remapping first
|
||||||
|
if name.endswith(
|
||||||
|
(".k_scale", ".v_scale", ".q_scale", ".prob_scale")):
|
||||||
|
# Try to remap the scale name first
|
||||||
|
remapped_name = maybe_remap_kv_scale_name(name, params_dict)
|
||||||
|
if remapped_name is not None and remapped_name in params_dict:
|
||||||
|
# Successfully remapped, use the remapped name
|
||||||
|
param = params_dict[remapped_name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
loaded_params.add(remapped_name)
|
||||||
|
continue
|
||||||
|
# If remapping failed, continue with normal processing
|
||||||
|
|
||||||
for (param_name, shard_name, shard_id) in stacked_params_mapping:
|
for (param_name, shard_name, shard_id) in stacked_params_mapping:
|
||||||
if shard_name not in name:
|
if shard_name not in name:
|
||||||
continue
|
continue
|
||||||
|
|||||||
@ -20,7 +20,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding)
|
VocabParallelEmbedding)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import (
|
||||||
|
default_weight_loader, maybe_remap_kv_scale_name)
|
||||||
|
|
||||||
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
|
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
|
||||||
|
|
||||||
@ -506,6 +507,21 @@ class SiglipVisionModel(nn.Module):
|
|||||||
if layer_idx >= layer_count:
|
if layer_idx >= layer_count:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Check if this is a scale parameter that needs remapping first
|
||||||
|
if name.endswith(
|
||||||
|
(".k_scale", ".v_scale", ".q_scale", ".prob_scale")):
|
||||||
|
# Try to remap the scale name first
|
||||||
|
remapped_name = maybe_remap_kv_scale_name(name, params_dict)
|
||||||
|
if remapped_name is not None and remapped_name in params_dict:
|
||||||
|
# Successfully remapped, use the remapped name
|
||||||
|
param = params_dict[remapped_name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
loaded_params.add(remapped_name)
|
||||||
|
continue
|
||||||
|
# If remapping failed, continue with normal processing
|
||||||
|
|
||||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||||
if weight_name not in name:
|
if weight_name not in name:
|
||||||
continue
|
continue
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user