mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 05:15:42 +08:00
Enable ModelOpt Llama4 fp8 checkpoint deployment (#20419)
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
This commit is contained in:
parent
5de8d9f111
commit
4afe687a82
@ -81,6 +81,16 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
|||||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def uses_weight_scale_2_pattern(self) -> bool:
    """Report whether per-tensor weight scales are named 'weight_scale_2'.

    The base implementation answers False, meaning the standard
    'weight_scale' naming applies. Subclasses whose per-tensor weight
    scales follow the 'weight_scale_2' pattern (e.g., FP4 variants) should
    override this method and return True.
    """
    return False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def maybe_make_prepare_finalize(
|
def maybe_make_prepare_finalize(
|
||||||
moe: FusedMoEConfig) -> Optional[FusedMoEPrepareAndFinalize]:
|
moe: FusedMoEConfig) -> Optional[FusedMoEPrepareAndFinalize]:
|
||||||
@ -1081,12 +1091,23 @@ class FusedMoE(torch.nn.Module):
|
|||||||
|
|
||||||
# TODO @dsikka: ModelOpt should follow the proper MoE loading pattern
|
# TODO @dsikka: ModelOpt should follow the proper MoE loading pattern
|
||||||
if "ModelOpt" in quant_method_name:
|
if "ModelOpt" in quant_method_name:
|
||||||
if ('weight_scale_2' in weight_name
|
# Determine per-tensor weight scale patterns based on variant
|
||||||
or 'input_scale' in weight_name):
|
# Use the dedicated method instead of brittle string matching
|
||||||
self._load_per_tensor_weight_scale(shard_id=shard_id,
|
uses_weight_scale_2 = self.quant_method.uses_weight_scale_2_pattern(
|
||||||
param=param,
|
)
|
||||||
loaded_weight=loaded_weight,
|
|
||||||
expert_id=expert_id)
|
# For per-tensor, FP4 uses "weight_scale_2", FP8 uses "weight_scale"
|
||||||
|
per_tensor_conditions = (
|
||||||
|
"weight_scale_2" in weight_name if uses_weight_scale_2 else
|
||||||
|
"weight_scale" in weight_name) or "input_scale" in weight_name
|
||||||
|
|
||||||
|
if per_tensor_conditions:
|
||||||
|
self._load_per_tensor_weight_scale(
|
||||||
|
shard_id=shard_id,
|
||||||
|
param=param,
|
||||||
|
loaded_weight=loaded_weight,
|
||||||
|
expert_id=expert_id,
|
||||||
|
)
|
||||||
elif "weight" in weight_name:
|
elif "weight" in weight_name:
|
||||||
self._load_model_weight_or_group_weight_scale(
|
self._load_model_weight_or_group_weight_scale(
|
||||||
shard_id=shard_id,
|
shard_id=shard_id,
|
||||||
@ -1558,3 +1579,7 @@ direct_register_custom_op(
|
|||||||
dispatch_key=current_platform.dispatch_key,
|
dispatch_key=current_platform.dispatch_key,
|
||||||
tags=(torch.Tag.needs_fixed_stride_order, ),
|
tags=(torch.Tag.needs_fixed_stride_order, ),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Mark the FusedMoE weight_loader as supporting MoE-specific parameters
|
||||||
|
# to avoid expensive runtime reflection in model loading code
|
||||||
|
FusedMoE.weight_loader.supports_moe_loading = True # type: ignore[attr-defined]
|
||||||
|
|||||||
@ -42,9 +42,13 @@ class ModelOptFp8Config(QuantizationConfig):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
is_checkpoint_fp8_serialized: bool = False,
|
is_checkpoint_fp8_serialized: bool = False,
|
||||||
|
kv_cache_quant_method: Optional[str] = None,
|
||||||
|
exclude_modules: Optional[list[str]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
|
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
|
||||||
|
self.kv_cache_quant_method = kv_cache_quant_method
|
||||||
|
self.exclude_modules = exclude_modules
|
||||||
if is_checkpoint_fp8_serialized:
|
if is_checkpoint_fp8_serialized:
|
||||||
logger.warning("Detected ModelOpt fp8 checkpoint. Please note that"
|
logger.warning("Detected ModelOpt fp8 checkpoint. Please note that"
|
||||||
" the format is experimental and could change.")
|
" the format is experimental and could change.")
|
||||||
@ -69,6 +73,11 @@ class ModelOptFp8Config(QuantizationConfig):
|
|||||||
def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config":
|
def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config":
|
||||||
quant_config = cls.get_from_keys(config, ["quantization"])
|
quant_config = cls.get_from_keys(config, ["quantization"])
|
||||||
quant_method = quant_config["quant_algo"]
|
quant_method = quant_config["quant_algo"]
|
||||||
|
kv_cache_quant_method = cls.get_from_keys(
|
||||||
|
config, ["quantization"]).get("kv_cache_quant_algo")
|
||||||
|
exclude_modules = cls.get_from_keys(
|
||||||
|
config, ["quantization"]).get("exclude_modules")
|
||||||
|
|
||||||
if quant_method not in QUANT_ALGOS:
|
if quant_method not in QUANT_ALGOS:
|
||||||
raise ValueError(f"ModelOpt currently only supports: {QUANT_ALGOS}"
|
raise ValueError(f"ModelOpt currently only supports: {QUANT_ALGOS}"
|
||||||
" quantizations in vLLM. Please check the "
|
" quantizations in vLLM. Please check the "
|
||||||
@ -76,27 +85,51 @@ class ModelOptFp8Config(QuantizationConfig):
|
|||||||
"quant configuration.")
|
"quant configuration.")
|
||||||
is_checkpoint_fp8_serialized = ("FP8" in quant_method)
|
is_checkpoint_fp8_serialized = ("FP8" in quant_method)
|
||||||
|
|
||||||
return cls(is_checkpoint_fp8_serialized)
|
return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method,
|
||||||
|
exclude_modules)
|
||||||
|
|
||||||
|
def is_layer_excluded(self, prefix: str) -> bool:
    """Check if a layer should be excluded from quantization.

    A layer is excluded when any entry of ``self.exclude_modules`` occurs
    as a substring of ``prefix``. Because the match is substring-based, a
    multimodal prefix that starts with ``language_model.`` still matches an
    entry that was listed without that prefix; no separate check on the
    stripped prefix is needed (anything found in the stripped prefix is
    necessarily found in the full one).

    Args:
        prefix: Module name of the layer being inspected.

    Returns:
        True if the layer matches an excluded module, False otherwise
        (including when no exclusion list is configured).
    """
    if self.exclude_modules is None:
        return False

    return any(module in prefix for module in self.exclude_modules)
|
||||||
|
|
||||||
def get_quant_method(self, layer: torch.nn.Module,
|
def get_quant_method(self, layer: torch.nn.Module,
|
||||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||||
from vllm.attention.layer import Attention # Avoid circular import
|
from vllm.attention.layer import Attention # Avoid circular import
|
||||||
if isinstance(layer, LinearBase):
|
if isinstance(layer, LinearBase):
|
||||||
|
if self.is_layer_excluded(prefix):
|
||||||
|
return UnquantizedLinearMethod()
|
||||||
return ModelOptFp8LinearMethod(self)
|
return ModelOptFp8LinearMethod(self)
|
||||||
elif isinstance(layer, Attention):
|
elif isinstance(layer, Attention):
|
||||||
return ModelOptFp8KVCacheMethod(self)
|
return ModelOptFp8KVCacheMethod(self)
|
||||||
|
elif isinstance(layer, FusedMoE):
|
||||||
|
return ModelOptFp8MoEMethod(self)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
class ModelOptFp8LinearMethod(LinearMethodBase):
|
class ModelOptFp8LinearMethod(LinearMethodBase):
|
||||||
"""Linear method for Model Optimizer static quantization.
|
"""Linear method for Model Optimizer static quantization.
|
||||||
Supports loading FP8 checkpoints with static weight scale and
|
Supports loading FP8 checkpoints with static weight scale and
|
||||||
activation scale. Future support might be added for dynamic
|
activation scale. Future support might be added for dynamic
|
||||||
scales.
|
scales.
|
||||||
|
|
||||||
Limitations:
|
Limitations:
|
||||||
1. Only support per-tensor quantization due to torch._scaled_mm support.
|
1. Only support per-tensor quantization due to torch._scaled_mm support.
|
||||||
2. Only support float8_e4m3fn datatype
|
2. Only support float8_e4m3fn datatype
|
||||||
Args: quant_config: The ModelOpt quantization config.
|
Args: quant_config: The ModelOpt quantization config.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -172,6 +205,223 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
|
|||||||
bias=bias)
|
bias=bias)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||||
|
"""MoE method for ModelOpt FP8.
|
||||||
|
Supports loading FP8 checkpoints with static weight scale and
|
||||||
|
activation scale.
|
||||||
|
Args:
|
||||||
|
quant_config: The ModelOpt quantization config.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, quant_config: ModelOptFp8Config):
|
||||||
|
self.quant_config = quant_config
|
||||||
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
|
cutlass_fp8_supported)
|
||||||
|
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
||||||
|
|
||||||
|
def create_weights(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
num_experts: int,
|
||||||
|
hidden_size: int,
|
||||||
|
intermediate_size_per_partition: int,
|
||||||
|
params_dtype: torch.dtype,
|
||||||
|
**extra_weight_attrs,
|
||||||
|
):
|
||||||
|
|
||||||
|
# Use FP8 dtype if checkpoint is serialized
|
||||||
|
weight_dtype = (torch.float8_e4m3fn
|
||||||
|
if self.quant_config.is_checkpoint_fp8_serialized else
|
||||||
|
params_dtype)
|
||||||
|
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||||
|
|
||||||
|
w13_weight = ModelWeightParameter(
|
||||||
|
data=torch.empty(num_experts,
|
||||||
|
2 * intermediate_size_per_partition,
|
||||||
|
hidden_size,
|
||||||
|
dtype=weight_dtype),
|
||||||
|
input_dim=2,
|
||||||
|
output_dim=1,
|
||||||
|
weight_loader=weight_loader,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w13_weight", w13_weight)
|
||||||
|
|
||||||
|
w2_weight = ModelWeightParameter(
|
||||||
|
data=torch.empty(num_experts,
|
||||||
|
hidden_size,
|
||||||
|
intermediate_size_per_partition,
|
||||||
|
dtype=weight_dtype),
|
||||||
|
input_dim=2,
|
||||||
|
output_dim=1,
|
||||||
|
weight_loader=weight_loader,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w2_weight", w2_weight)
|
||||||
|
|
||||||
|
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||||
|
# WEIGHT SCALES - Per-tensor scaling for ModelOpt
|
||||||
|
# Allocate 2 scales for w1 and w3 respectively.
|
||||||
|
# They will be combined to a single scale after weight loading.
|
||||||
|
w13_weight_scale = PerTensorScaleParameter(
|
||||||
|
data=torch.full(
|
||||||
|
(num_experts, 2),
|
||||||
|
1.0,
|
||||||
|
dtype=torch.float32,
|
||||||
|
),
|
||||||
|
weight_loader=weight_loader,
|
||||||
|
)
|
||||||
|
w2_weight_scale = PerTensorScaleParameter(
|
||||||
|
data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
|
||||||
|
weight_loader=weight_loader,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||||
|
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||||
|
|
||||||
|
# Set weight loader attributes for scales
|
||||||
|
extra_weight_attrs.update(
|
||||||
|
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
|
||||||
|
|
||||||
|
# INPUT SCALES - Per-tensor scaling for ModelOpt
|
||||||
|
w13_input_scale = PerTensorScaleParameter(
|
||||||
|
data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
|
||||||
|
weight_loader=weight_loader,
|
||||||
|
)
|
||||||
|
w2_input_scale = PerTensorScaleParameter(
|
||||||
|
data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
|
||||||
|
weight_loader=weight_loader,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w13_input_scale", w13_input_scale)
|
||||||
|
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
"""Process FP8 MoE weights after loading from serialized checkpoint.
|
||||||
|
Only supports pre-quantized checkpoints with FP8 weights and scales.
|
||||||
|
"""
|
||||||
|
|
||||||
|
layer.w13_weight = Parameter(layer.w13_weight.data,
|
||||||
|
requires_grad=False)
|
||||||
|
layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
|
||||||
|
|
||||||
|
from vllm._custom_ops import scaled_fp8_quant
|
||||||
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
|
per_tensor_dequantize)
|
||||||
|
|
||||||
|
# Handle scale parameters
|
||||||
|
if hasattr(layer,
|
||||||
|
"w13_weight_scale") and layer.w13_weight_scale is not None:
|
||||||
|
# Fp8 moe kernel needs single weight scale for w13 per expert.
|
||||||
|
# We take the max of the w1 and w3 scales
|
||||||
|
# then dequant and requant each expert.
|
||||||
|
if layer.w13_weight_scale.dim() == 2:
|
||||||
|
|
||||||
|
# Get the maximum scale across w1 and w3 for each expert
|
||||||
|
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
|
||||||
|
|
||||||
|
# Requantize each expert's weights using the combined scale
|
||||||
|
# w13_weight (num_experts, 2 * intermediate_size, hidden_size)
|
||||||
|
# where the first intermediate_size rows are w1, the next are w3
|
||||||
|
intermediate_size = layer.w13_weight.shape[1] // 2
|
||||||
|
for expert_id in range(layer.w13_weight.shape[0]):
|
||||||
|
start = 0
|
||||||
|
for shard_id in range(2): # w1 and w3
|
||||||
|
# Dequantize using the original scale for this shard
|
||||||
|
dq_weight = per_tensor_dequantize(
|
||||||
|
layer.w13_weight[expert_id][start:start +
|
||||||
|
intermediate_size, :],
|
||||||
|
layer.w13_weight_scale[expert_id][shard_id],
|
||||||
|
)
|
||||||
|
# Requantize using the combined max scale
|
||||||
|
|
||||||
|
(
|
||||||
|
layer.w13_weight[expert_id][start:start +
|
||||||
|
intermediate_size, :],
|
||||||
|
_,
|
||||||
|
) = scaled_fp8_quant(dq_weight,
|
||||||
|
max_w13_scales[expert_id])
|
||||||
|
|
||||||
|
start += intermediate_size
|
||||||
|
|
||||||
|
# Update the scale parameter to be per-expert
|
||||||
|
layer.w13_weight_scale = Parameter(max_w13_scales,
|
||||||
|
requires_grad=False)
|
||||||
|
else:
|
||||||
|
layer.w13_weight_scale = Parameter(layer.w13_weight_scale.data,
|
||||||
|
requires_grad=False)
|
||||||
|
|
||||||
|
if hasattr(layer,
|
||||||
|
"w2_weight_scale") and layer.w2_weight_scale is not None:
|
||||||
|
layer.w2_weight_scale = Parameter(layer.w2_weight_scale.data,
|
||||||
|
requires_grad=False)
|
||||||
|
# Input scales must be equal for each expert in fp8 MoE layers.
|
||||||
|
if hasattr(layer,
|
||||||
|
"w13_input_scale") and layer.w13_input_scale is not None:
|
||||||
|
layer.w13_input_scale = Parameter(layer.w13_input_scale.max(),
|
||||||
|
requires_grad=False)
|
||||||
|
if hasattr(layer,
|
||||||
|
"w2_input_scale") and layer.w2_input_scale is not None:
|
||||||
|
layer.w2_input_scale = Parameter(layer.w2_input_scale.max(),
|
||||||
|
requires_grad=False)
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
top_k: int,
|
||||||
|
renormalize: bool,
|
||||||
|
use_grouped_topk: bool = False,
|
||||||
|
topk_group: Optional[int] = None,
|
||||||
|
num_expert_group: Optional[int] = None,
|
||||||
|
global_num_experts: int = -1,
|
||||||
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
|
custom_routing_function: Optional[Callable] = None,
|
||||||
|
scoring_func: str = "softmax",
|
||||||
|
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
|
activation: str = "silu",
|
||||||
|
enable_eplb: bool = False,
|
||||||
|
expert_load_view: Optional[torch.Tensor] = None,
|
||||||
|
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||||
|
logical_replica_count: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if enable_eplb:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"EPLB not supported for `ModelOptFp8MoEMethod` yet.")
|
||||||
|
|
||||||
|
# Expert selection
|
||||||
|
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||||
|
hidden_states=x,
|
||||||
|
router_logits=router_logits,
|
||||||
|
use_grouped_topk=use_grouped_topk,
|
||||||
|
top_k=top_k,
|
||||||
|
renormalize=renormalize,
|
||||||
|
topk_group=topk_group,
|
||||||
|
num_expert_group=num_expert_group,
|
||||||
|
custom_routing_function=custom_routing_function,
|
||||||
|
scoring_func=scoring_func,
|
||||||
|
e_score_correction_bias=e_score_correction_bias,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||||
|
fused_experts)
|
||||||
|
return fused_experts(
|
||||||
|
x,
|
||||||
|
layer.w13_weight,
|
||||||
|
layer.w2_weight,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
inplace=True,
|
||||||
|
activation=activation,
|
||||||
|
use_fp8_w8a8=True,
|
||||||
|
per_channel_quant=False,
|
||||||
|
global_num_experts=global_num_experts,
|
||||||
|
expert_map=expert_map,
|
||||||
|
w1_scale=layer.w13_weight_scale,
|
||||||
|
w2_scale=layer.w2_weight_scale,
|
||||||
|
a1_scale=layer.w13_input_scale,
|
||||||
|
a2_scale=layer.w2_input_scale,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ModelOptNvFp4Config(QuantizationConfig):
|
class ModelOptNvFp4Config(QuantizationConfig):
|
||||||
"""Config class for ModelOpt FP4."""
|
"""Config class for ModelOpt FP4."""
|
||||||
|
|
||||||
@ -274,7 +524,7 @@ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
|
|||||||
class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
||||||
"""Linear method for Model Optimizer NVFP4.
|
"""Linear method for Model Optimizer NVFP4.
|
||||||
Supports loading NVFP4 checkpoints with the following structure:
|
Supports loading NVFP4 checkpoints with the following structure:
|
||||||
|
|
||||||
input_scale: torch.float32, scalar ,
|
input_scale: torch.float32, scalar ,
|
||||||
weight: NVFP4(represented as byte) Shape: [1, X, y/2]
|
weight: NVFP4(represented as byte) Shape: [1, X, y/2]
|
||||||
weight_scale: FP8-E4M3, Shape: [X, Y], aka per block scale,
|
weight_scale: FP8-E4M3, Shape: [X, Y], aka per block scale,
|
||||||
@ -455,7 +705,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
|||||||
class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||||
"""
|
"""
|
||||||
MoE Method for FP4 Quantization.
|
MoE Method for FP4 Quantization.
|
||||||
Args:
|
Args:
|
||||||
quant_config: NVFP4 Quant Config
|
quant_config: NVFP4 Quant Config
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -472,6 +722,12 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
" quantization. Please use Blackwell and"
|
" quantization. Please use Blackwell and"
|
||||||
" above.")
|
" above.")
|
||||||
|
|
||||||
|
def uses_weight_scale_2_pattern(self) -> bool:
    """Report that per-tensor weight scales use the 'weight_scale_2' name.

    FP4 variants follow the 'weight_scale_2' pattern for per-tensor weight
    scales, so this always returns True.
    """
    return True
|
||||||
|
|
||||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||||
hidden_size: int, intermediate_size_per_partition: int,
|
hidden_size: int, intermediate_size_per_partition: int,
|
||||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||||
|
|||||||
@ -762,6 +762,10 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
|
|||||||
modelopt_scale_names = [
|
modelopt_scale_names = [
|
||||||
".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"
|
".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"
|
||||||
]
|
]
|
||||||
|
# Also support qkv_proj scale parameters (from stacked parameter processing)
|
||||||
|
qkv_proj_scale_names = [
|
||||||
|
".self_attn.qkv_proj.k_scale", ".self_attn.qkv_proj.v_scale"
|
||||||
|
]
|
||||||
for scale_name in possible_scale_names:
|
for scale_name in possible_scale_names:
|
||||||
if name.endswith(scale_name):
|
if name.endswith(scale_name):
|
||||||
if any(mo_scale_name in name
|
if any(mo_scale_name in name
|
||||||
@ -769,6 +773,12 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
|
|||||||
remapped_name = name.replace(
|
remapped_name = name.replace(
|
||||||
f".self_attn.{scale_name[1]}_proj{scale_name}",
|
f".self_attn.{scale_name[1]}_proj{scale_name}",
|
||||||
f".self_attn.attn{scale_name}")
|
f".self_attn.attn{scale_name}")
|
||||||
|
elif any(qkv_scale_name in name
|
||||||
|
for qkv_scale_name in qkv_proj_scale_names):
|
||||||
|
# Handle qkv_proj scale parameters
|
||||||
|
remapped_name = name.replace(
|
||||||
|
f".self_attn.qkv_proj{scale_name}",
|
||||||
|
f".self_attn.attn{scale_name}")
|
||||||
else:
|
else:
|
||||||
remapped_name = name.replace(scale_name, f".attn{scale_name}")
|
remapped_name = name.replace(scale_name, f".attn{scale_name}")
|
||||||
if remapped_name not in params_dict:
|
if remapped_name not in params_dict:
|
||||||
|
|||||||
@ -35,7 +35,8 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
|||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import (
|
||||||
|
default_weight_loader, maybe_remap_kv_scale_name)
|
||||||
|
|
||||||
from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
|
from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
|
||||||
from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk,
|
from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk,
|
||||||
@ -432,12 +433,24 @@ class Llama4Model(LlamaModel):
|
|||||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
if weight_name not in name or "experts" in name:
|
if weight_name not in name or "experts" in name:
|
||||||
continue
|
continue
|
||||||
name = name.replace(weight_name, param_name)
|
# This check is for ModelOpt ckpts with kv cache quant enabled
|
||||||
|
if not (name.endswith(
|
||||||
|
(".k_scale", ".v_scale")) and "self_attn" in name):
|
||||||
|
name = name.replace(weight_name, param_name)
|
||||||
if is_pp_missing_parameter(name, self):
|
if is_pp_missing_parameter(name, self):
|
||||||
continue
|
continue
|
||||||
|
if name.endswith("scale") and "expert" not in name:
|
||||||
|
# Remapping the name of FP8 kv-scale.
|
||||||
|
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||||
|
if name is None:
|
||||||
|
continue
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = param.weight_loader
|
weight_loader = getattr(param, "weight_loader",
|
||||||
weight_loader(param, loaded_weight, shard_id)
|
default_weight_loader)
|
||||||
|
if weight_loader == default_weight_loader:
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
else:
|
||||||
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
loaded_params.add(name)
|
loaded_params.add(name)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
@ -452,6 +465,44 @@ class Llama4Model(LlamaModel):
|
|||||||
if not moe_loaded:
|
if not moe_loaded:
|
||||||
if is_pp_missing_parameter(name, self):
|
if is_pp_missing_parameter(name, self):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Handle flat expert scale parameters that
|
||||||
|
# don't match per-expert patterns
|
||||||
|
if ("experts." in name and ("w13_input_scale" in name
|
||||||
|
or "w13_weight_scale" in name
|
||||||
|
or "w2_input_scale" in name
|
||||||
|
or "w2_weight_scale" in name)):
|
||||||
|
# These are flat expert scales that apply to all experts
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
|
||||||
|
# Check for MoE-specific loading support via
|
||||||
|
# attribute instead of expensive runtime reflection
|
||||||
|
supports_moe = getattr(weight_loader,
|
||||||
|
'supports_moe_loading', False)
|
||||||
|
|
||||||
|
if supports_moe:
|
||||||
|
# This is a MoE weight loader
|
||||||
|
if "w13_" in name:
|
||||||
|
shard_id = "w1"
|
||||||
|
elif "w2_" in name:
|
||||||
|
shard_id = "w2"
|
||||||
|
else:
|
||||||
|
shard_id = "w1"
|
||||||
|
|
||||||
|
weight_loader(param,
|
||||||
|
loaded_weight,
|
||||||
|
name,
|
||||||
|
shard_id=shard_id,
|
||||||
|
expert_id=0)
|
||||||
|
else:
|
||||||
|
# Regular weight loader (handles both
|
||||||
|
# param.weight_loader and default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
loaded_params.add(name)
|
||||||
|
continue
|
||||||
|
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = getattr(param, "weight_loader",
|
weight_loader = getattr(param, "weight_loader",
|
||||||
default_weight_loader)
|
default_weight_loader)
|
||||||
|
|||||||
@ -717,6 +717,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
SupportsPP):
|
SupportsPP):
|
||||||
packed_modules_mapping = {
|
packed_modules_mapping = {
|
||||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||||
|
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -902,32 +903,109 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
qkv_weight = torch.cat(weight, dim=0)
|
qkv_weight = torch.cat(weight, dim=0)
|
||||||
yield key, qkv_weight
|
yield key, qkv_weight
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[tuple[str,
|
def _rename_weight_for_modelopt_checkpoint(self, name: str) -> str:
|
||||||
torch.Tensor]]) -> set[str]:
|
"""Rename weights from ModelOpt llama4 fp8 checkpoints to vLLM
|
||||||
|
format."""
|
||||||
|
if name.startswith("model."):
|
||||||
|
# Handle expert scale parameters with flat naming
|
||||||
|
if "feed_forward.experts." in name and ("_input_scale" in name or
|
||||||
|
"_weight_scale" in name):
|
||||||
|
renamed = name.replace("model.", "language_model.model.", 1)
|
||||||
|
# Map checkpoint naming to vLLM's expected naming
|
||||||
|
if "down_proj_input_scale" in renamed:
|
||||||
|
return renamed.replace("down_proj_input_scale",
|
||||||
|
"w2_input_scale")
|
||||||
|
elif "down_proj_weight_scale" in renamed:
|
||||||
|
return renamed.replace("down_proj_weight_scale",
|
||||||
|
"w2_weight_scale")
|
||||||
|
elif "gate_up_proj_input_scale" in renamed:
|
||||||
|
return renamed.replace("gate_up_proj_input_scale",
|
||||||
|
"w13_input_scale")
|
||||||
|
elif "gate_up_proj_weight_scale" in renamed:
|
||||||
|
return renamed.replace("gate_up_proj_weight_scale",
|
||||||
|
"w13_weight_scale")
|
||||||
|
return renamed
|
||||||
|
|
||||||
stacked_params_mapping = [
|
# Handle attention scale parameters
|
||||||
# (param_name, shard_name, shard_id)
|
elif "self_attn." in name and (".k_scale" in name
|
||||||
(".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
|
or ".v_scale" in name):
|
||||||
(".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
|
renamed = name.replace("model.", "language_model.model.", 1)
|
||||||
(".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
|
if ".k_proj.k_scale" in renamed:
|
||||||
]
|
return renamed.replace(".k_proj.k_scale", ".attn.k_scale")
|
||||||
params_dict = dict(self.named_parameters())
|
elif ".v_proj.v_scale" in renamed:
|
||||||
updated_params: set[str] = set()
|
return renamed.replace(".v_proj.v_scale", ".attn.v_scale")
|
||||||
|
return renamed
|
||||||
|
|
||||||
# language_model is a Llama4ForCausalLM instance. We load it
|
# Standard model.* to language_model.model.* renaming
|
||||||
# using llama4's load_weights routine.
|
return name.replace("model.", "language_model.model.", 1)
|
||||||
language_model_weights, other_weights = self.separate_weights(
|
|
||||||
weights, prefix="language_model.")
|
elif name.startswith("lm_head.weight"):
|
||||||
loader = AutoWeightsLoader(self)
|
return name.replace("lm_head.weight",
|
||||||
loaded_language_model_params = loader.load_weights(
|
"language_model.lm_head.weight")
|
||||||
language_model_weights)
|
|
||||||
assert loaded_language_model_params is not None
|
return name
|
||||||
updated_params.update(loaded_language_model_params)
|
|
||||||
|
def _separate_and_rename_weights(
    self, weights: Iterable[tuple[str, torch.Tensor]]
) -> tuple[list[tuple[str, torch.Tensor]], list[tuple[str, torch.Tensor]]]:
    """Rename every weight, then bucket it by its renamed prefix.

    Each incoming name is passed through
    ``_rename_weight_for_modelopt_checkpoint``. Entries whose renamed name
    starts with ``language_model.`` go into the first returned list; all
    others go into the second. Input order is preserved within each list.
    """
    lm_bucket = []
    rest_bucket = []

    for checkpoint_name, tensor in weights:
        new_name = self._rename_weight_for_modelopt_checkpoint(
            checkpoint_name)
        # Route on the renamed name, not the checkpoint name.
        bucket = (lm_bucket
                  if new_name.startswith("language_model.") else rest_bucket)
        bucket.append((new_name, tensor))

    return lm_bucket, rest_bucket
|
||||||
|
|
||||||
|
def _handle_expert_scale_broadcasting(
    self, weights: list[tuple[str, torch.Tensor]], params_dict: dict
) -> tuple[list[tuple[str, torch.Tensor]], list[tuple[str, torch.Tensor]],
           set[str]]:
    """Handle expert scale parameters that need broadcasting.

    ModelOpt checkpoints store a single scalar scale for BMM-style experts,
    while vLLM expects the scale to be broadcast across all experts.

    A weight is treated as an expert scale when its name contains
    ``feed_forward.experts.`` and ``scale`` but not ``.shared_expert``.
    For such a weight, if the matching parameter in ``params_dict`` holds
    more than one element and the loaded weight holds exactly one, the
    parameter is filled in place with that value and the weight is
    consumed here.

    Args:
        weights: ``(name, tensor)`` pairs to classify.
        params_dict: Mapping from parameter name to parameter.

    Returns:
        A 3-tuple ``(regular_weights, expert_scale_weights,
        updated_params)``: the weights that are not expert scales, the
        expert scales that were not broadcast here and still need loading,
        and the names of the parameters filled by this method.
    """
    regular_weights: list[tuple[str, torch.Tensor]] = []
    expert_scale_weights: list[tuple[str, torch.Tensor]] = []
    updated_params: set[str] = set()

    for name, weight in weights:
        is_expert_scale = ("feed_forward.experts." in name
                           and "scale" in name
                           and ".shared_expert" not in name)
        if not is_expert_scale:
            regular_weights.append((name, weight))
            continue

        param = params_dict.get(name)
        if (param is not None and hasattr(param, "data")
                and param.data.numel() > 1 and weight.numel() == 1):
            # Broadcast the single checkpoint value to every expert slot.
            param.data.fill_(weight.item())
            updated_params.add(name)
            continue

        # Expert scale that cannot be broadcast here; load it normally.
        expert_scale_weights.append((name, weight))

    return regular_weights, expert_scale_weights, updated_params
|
||||||
|
|
||||||
|
def _load_other_weights(self, other_weights: Iterable[tuple[str,
|
||||||
|
torch.Tensor]],
|
||||||
|
params_dict: dict,
|
||||||
|
stacked_params_mapping: list) -> set[str]:
|
||||||
|
"""Load non-language-model weights with stacking support."""
|
||||||
|
updated_params = set()
|
||||||
|
|
||||||
if self.use_data_parallel:
|
if self.use_data_parallel:
|
||||||
other_weights = self._consolidate_qkv_weights(other_weights)
|
other_weights = self._consolidate_qkv_weights(other_weights)
|
||||||
|
|
||||||
for name, loaded_weight in other_weights:
|
for name, loaded_weight in other_weights:
|
||||||
|
# Try stacked parameter mapping first
|
||||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
if weight_name not in name or self.use_data_parallel:
|
if weight_name not in name or self.use_data_parallel:
|
||||||
continue
|
continue
|
||||||
@ -938,10 +1016,56 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
weight_loader(param, loaded_weight, shard_id)
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
|
# Use regular weight loading
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = getattr(param, "weight_loader",
|
weight_loader = getattr(param, "weight_loader",
|
||||||
default_weight_loader)
|
default_weight_loader)
|
||||||
|
|
||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
updated_params.add(name)
|
updated_params.add(name)
|
||||||
|
|
||||||
|
return updated_params
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[tuple[str,
|
||||||
|
torch.Tensor]]) -> set[str]:
|
||||||
|
|
||||||
|
stacked_params_mapping = [
|
||||||
|
# (param_name, shard_name, shard_id)
|
||||||
|
(".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
|
||||||
|
(".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
|
||||||
|
(".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
|
||||||
|
# Shared expert gate_up_proj stacking
|
||||||
|
(".shared_expert.gate_up_proj", ".shared_expert.gate_proj", 0),
|
||||||
|
(".shared_expert.gate_up_proj", ".shared_expert.up_proj", 1),
|
||||||
|
# Feed forward gate_up_proj stacking (for non-MoE layers if any)
|
||||||
|
(".feed_forward.gate_up_proj", ".feed_forward.gate_proj", 0),
|
||||||
|
(".feed_forward.gate_up_proj", ".feed_forward.up_proj", 1),
|
||||||
|
]
|
||||||
|
params_dict = dict(self.named_parameters())
|
||||||
|
updated_params: set[str] = set()
|
||||||
|
|
||||||
|
# Separate and rename weights
|
||||||
|
language_model_weights, other_weights = (
|
||||||
|
self._separate_and_rename_weights(weights))
|
||||||
|
|
||||||
|
# Handle expert scale parameters
|
||||||
|
regular_weights, expert_scale_weights, updated_params_from_experts = (
|
||||||
|
self._handle_expert_scale_broadcasting(language_model_weights,
|
||||||
|
params_dict))
|
||||||
|
updated_params.update(updated_params_from_experts)
|
||||||
|
|
||||||
|
loader = AutoWeightsLoader(self)
|
||||||
|
loaded_language_model_params = loader.load_weights(regular_weights)
|
||||||
|
assert loaded_language_model_params is not None
|
||||||
|
updated_params.update(loaded_language_model_params)
|
||||||
|
|
||||||
|
if expert_scale_weights:
|
||||||
|
loaded_expert_scale_params = loader.load_weights(
|
||||||
|
expert_scale_weights)
|
||||||
|
if loaded_expert_scale_params:
|
||||||
|
updated_params.update(loaded_expert_scale_params)
|
||||||
|
|
||||||
|
updated_params.update(
|
||||||
|
self._load_other_weights(other_weights, params_dict,
|
||||||
|
stacked_params_mapping))
|
||||||
|
|
||||||
return updated_params
|
return updated_params
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user