mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-04 21:14:09 +08:00
[MoE][Refactor 1/N] Separate Online Quantization (#30627)
Signed-off-by: Robert Shaw <robshaw@redhat.com> Co-authored-by: Robert Shaw <robshaw@redhat.com>
This commit is contained in:
parent
3f175f18a2
commit
d0502b4928
@ -332,7 +332,10 @@ class Fp8Config(QuantizationConfig):
|
||||
fused_mapping=self.packed_modules_mapping,
|
||||
):
|
||||
return UnquantizedFusedMoEMethod(layer.moe_config)
|
||||
moe_quant_method = Fp8MoEMethod(self, layer)
|
||||
if self.is_checkpoint_fp8_serialized:
|
||||
moe_quant_method = Fp8MoEMethod(self, layer)
|
||||
else:
|
||||
moe_quant_method = Fp8OnlineMoEMethod(self, layer)
|
||||
moe_quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
|
||||
return moe_quant_method
|
||||
elif isinstance(layer, Attention):
|
||||
@ -745,8 +748,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.orig_dtype = params_dtype
|
||||
layer.weight_block_size = None
|
||||
|
||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||
params_dtype = torch.float8_e4m3fn
|
||||
assert self.quant_config.is_checkpoint_fp8_serialized
|
||||
params_dtype = torch.float8_e4m3fn
|
||||
|
||||
if self.block_quant:
|
||||
assert self.weight_block_size is not None
|
||||
layer.weight_block_size = self.weight_block_size
|
||||
@ -773,41 +777,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
f"weight quantization block_k = {block_k}."
|
||||
)
|
||||
|
||||
# if we are doing online quantization, patch the weight
|
||||
# loaded to call `process_weights_after_loading` in a streaming fashion
|
||||
# as soon as the last weight chunk is loaded
|
||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
weight_loader = extra_weight_attrs["weight_loader"]
|
||||
# create a new holder to prevent modifying behavior of any other
|
||||
# objects which might depend on the old one
|
||||
new_extra_weight_attrs = extra_weight_attrs
|
||||
|
||||
def patched_weight_loader(param, loaded_weight, *args, **kwargs):
|
||||
# load the current weight chunk
|
||||
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]
|
||||
|
||||
# add a counter to track how many elements we have updated
|
||||
if not hasattr(layer, "_loaded_numel"):
|
||||
layer._loaded_numel = 0
|
||||
layer._loaded_numel += loaded_weight.numel()
|
||||
|
||||
# if we have loaded all of the elements, call
|
||||
# process_weights_after_loading
|
||||
target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
|
||||
if layer._loaded_numel == target_loaded_numel:
|
||||
self.process_weights_after_loading(layer)
|
||||
|
||||
# Delete the bookkeeping
|
||||
del layer._loaded_numel
|
||||
# Prevent the usual `process_weights_after_loading` call
|
||||
# from doing anything
|
||||
layer._already_called_process_weights_after_loading = True
|
||||
|
||||
return res
|
||||
|
||||
new_extra_weight_attrs["weight_loader"] = patched_weight_loader
|
||||
extra_weight_attrs = new_extra_weight_attrs
|
||||
|
||||
# WEIGHTS
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
@ -875,21 +844,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
if self.block_quant
|
||||
else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
|
||||
)
|
||||
# If loading fp8 checkpoint, pass the weight loaders.
|
||||
# If loading an fp16 checkpoint, do not (we will quantize in
|
||||
# process_weights_after_loading()
|
||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||
|
||||
# INPUT_SCALES
|
||||
if self.quant_config.activation_scheme == "static":
|
||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
raise ValueError(
|
||||
"Found static activation scheme for checkpoint that "
|
||||
"was not serialized fp8."
|
||||
)
|
||||
|
||||
w13_input_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
|
||||
)
|
||||
@ -986,45 +945,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_weight_scale_inv = Parameter(
|
||||
dg_w2_weight_scale_inv, requires_grad=False
|
||||
)
|
||||
|
||||
# If checkpoint is fp16, quantize in place.
|
||||
elif not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
fp8_dtype = current_platform.fp8_dtype()
|
||||
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
|
||||
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
|
||||
|
||||
# Re-initialize w13_scale because we directly quantize
|
||||
# merged w13 weights and generate a single scaling factor.
|
||||
replace_parameter(
|
||||
layer,
|
||||
"w13_weight_scale",
|
||||
torch.ones(
|
||||
layer.local_num_experts,
|
||||
dtype=torch.float32,
|
||||
device=w13_weight.device,
|
||||
),
|
||||
)
|
||||
for expert in range(layer.local_num_experts):
|
||||
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
|
||||
ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
|
||||
)
|
||||
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
|
||||
ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
|
||||
)
|
||||
replace_parameter(layer, "w13_weight", w13_weight)
|
||||
replace_parameter(layer, "w2_weight", w2_weight)
|
||||
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
# reshaping weights is required for aiter moe kernel.
|
||||
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
||||
layer.w13_weight, layer.w2_weight
|
||||
)
|
||||
|
||||
replace_parameter(layer, "w13_weight", shuffled_w13)
|
||||
replace_parameter(layer, "w2_weight", shuffled_w2)
|
||||
# If checkpoint is fp8, we need to handle that the
|
||||
# MoE kernels require single activation scale and single weight
|
||||
# scale for w13 per expert.
|
||||
else:
|
||||
# Fp8 moe kernels require a single activation scale.
|
||||
# We take the max of all the scales in case they differ.
|
||||
@ -1387,6 +1307,151 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
return result
|
||||
|
||||
|
||||
class Fp8OnlineMoEMethod(Fp8MoEMethod):
|
||||
"""MoE method for online FP8 quantization.
|
||||
Supports loading quantized FP16/BF16 model checkpoints with dynamic
|
||||
activation scaling. The weight scaling factor will be initialized after
|
||||
the model weights are loaded.
|
||||
|
||||
Args:
|
||||
quant_config: The quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
|
||||
super().__init__(quant_config, layer)
|
||||
assert not quant_config.is_checkpoint_fp8_serialized
|
||||
assert quant_config.activation_scheme == "dynamic"
|
||||
assert quant_config.weight_block_size is None
|
||||
assert self.flashinfer_moe_backend is None
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
layer.intermediate_size_per_partition = intermediate_size_per_partition
|
||||
layer.hidden_size = hidden_size
|
||||
layer.num_experts = num_experts
|
||||
layer.orig_dtype = params_dtype
|
||||
layer.weight_block_size = None
|
||||
|
||||
# We are doing online quantization, patch the weight loaded
|
||||
# to call `process_weights_after_loading` in a streaming fashion
|
||||
# as soon as the last weight chunk is loaded.
|
||||
weight_loader = extra_weight_attrs["weight_loader"]
|
||||
# create a new holder to prevent modifying behavior of any other
|
||||
# objects which might depend on the old one
|
||||
new_extra_weight_attrs = extra_weight_attrs
|
||||
|
||||
def patched_weight_loader(param, loaded_weight, *args, **kwargs):
|
||||
# load the current weight chunk
|
||||
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]
|
||||
|
||||
# add a counter to track how many elements we have updated
|
||||
if not hasattr(layer, "_loaded_numel"):
|
||||
layer._loaded_numel = 0
|
||||
layer._loaded_numel += loaded_weight.numel()
|
||||
|
||||
# if we have loaded all of the elements, call
|
||||
# process_weights_after_loading
|
||||
target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
|
||||
if layer._loaded_numel == target_loaded_numel:
|
||||
self.process_weights_after_loading(layer)
|
||||
|
||||
# Delete the bookkeeping
|
||||
del layer._loaded_numel
|
||||
# Prevent the usual `process_weights_after_loading` call
|
||||
# from doing anything
|
||||
layer._already_called_process_weights_after_loading = True
|
||||
|
||||
return res
|
||||
|
||||
new_extra_weight_attrs["weight_loader"] = patched_weight_loader
|
||||
extra_weight_attrs = new_extra_weight_attrs
|
||||
|
||||
# WEIGHTS
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
# WEIGHT_SCALES
|
||||
# Allocate 2 scales for w1 and w3 respectively.
|
||||
# They will be combined to a single scale after weight loading.
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
|
||||
)
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
|
||||
)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
|
||||
layer.w13_input_scale = None
|
||||
layer.w2_input_scale = None
|
||||
|
||||
self.rocm_aiter_moe_enabled = False
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
if getattr(layer, "_already_called_process_weights_after_loading", False):
|
||||
return
|
||||
|
||||
# Lazy import to avoid importing triton too early.
|
||||
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
|
||||
|
||||
# If checkpoint is fp16, quantize in place.
|
||||
fp8_dtype = current_platform.fp8_dtype()
|
||||
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
|
||||
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
|
||||
|
||||
for expert in range(layer.local_num_experts):
|
||||
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
|
||||
ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
|
||||
)
|
||||
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
|
||||
ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
|
||||
)
|
||||
replace_parameter(layer, "w13_weight", w13_weight)
|
||||
replace_parameter(layer, "w2_weight", w2_weight)
|
||||
|
||||
# Reshuffle weights for AITER if needed.
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
||||
layer.w13_weight, layer.w2_weight
|
||||
)
|
||||
replace_parameter(layer, "w13_weight", shuffled_w13)
|
||||
replace_parameter(layer, "w2_weight", shuffled_w2)
|
||||
|
||||
# Rushuffle weights for MARLIN if needed.
|
||||
if self.use_marlin:
|
||||
prepare_moe_fp8_layer_for_marlin(
|
||||
layer, False, input_dtype=self.marlin_input_dtype
|
||||
)
|
||||
|
||||
|
||||
class Fp8KVCacheMethod(BaseKVCacheMethod):
|
||||
"""
|
||||
Supports loading kv-cache scaling factors from FP8 checkpoints.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user