mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-02 04:17:51 +08:00
[MoE][Refactor 1/N] Separate Online Quantization (#30627)
Signed-off-by: Robert Shaw <robshaw@redhat.com> Co-authored-by: Robert Shaw <robshaw@redhat.com>
This commit is contained in:
parent
3f175f18a2
commit
d0502b4928
@ -332,7 +332,10 @@ class Fp8Config(QuantizationConfig):
|
|||||||
fused_mapping=self.packed_modules_mapping,
|
fused_mapping=self.packed_modules_mapping,
|
||||||
):
|
):
|
||||||
return UnquantizedFusedMoEMethod(layer.moe_config)
|
return UnquantizedFusedMoEMethod(layer.moe_config)
|
||||||
moe_quant_method = Fp8MoEMethod(self, layer)
|
if self.is_checkpoint_fp8_serialized:
|
||||||
|
moe_quant_method = Fp8MoEMethod(self, layer)
|
||||||
|
else:
|
||||||
|
moe_quant_method = Fp8OnlineMoEMethod(self, layer)
|
||||||
moe_quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
|
moe_quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
|
||||||
return moe_quant_method
|
return moe_quant_method
|
||||||
elif isinstance(layer, Attention):
|
elif isinstance(layer, Attention):
|
||||||
@ -745,8 +748,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
layer.orig_dtype = params_dtype
|
layer.orig_dtype = params_dtype
|
||||||
layer.weight_block_size = None
|
layer.weight_block_size = None
|
||||||
|
|
||||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
assert self.quant_config.is_checkpoint_fp8_serialized
|
||||||
params_dtype = torch.float8_e4m3fn
|
params_dtype = torch.float8_e4m3fn
|
||||||
|
|
||||||
if self.block_quant:
|
if self.block_quant:
|
||||||
assert self.weight_block_size is not None
|
assert self.weight_block_size is not None
|
||||||
layer.weight_block_size = self.weight_block_size
|
layer.weight_block_size = self.weight_block_size
|
||||||
@ -773,41 +777,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
f"weight quantization block_k = {block_k}."
|
f"weight quantization block_k = {block_k}."
|
||||||
)
|
)
|
||||||
|
|
||||||
# if we are doing online quantization, patch the weight
|
|
||||||
# loader to call `process_weights_after_loading` in a streaming fashion
|
|
||||||
# as soon as the last weight chunk is loaded
|
|
||||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
|
||||||
weight_loader = extra_weight_attrs["weight_loader"]
|
|
||||||
# create a new holder to prevent modifying behavior of any other
|
|
||||||
# objects which might depend on the old one
|
|
||||||
new_extra_weight_attrs = extra_weight_attrs
|
|
||||||
|
|
||||||
def patched_weight_loader(param, loaded_weight, *args, **kwargs):
|
|
||||||
# load the current weight chunk
|
|
||||||
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]
|
|
||||||
|
|
||||||
# add a counter to track how many elements we have updated
|
|
||||||
if not hasattr(layer, "_loaded_numel"):
|
|
||||||
layer._loaded_numel = 0
|
|
||||||
layer._loaded_numel += loaded_weight.numel()
|
|
||||||
|
|
||||||
# if we have loaded all of the elements, call
|
|
||||||
# process_weights_after_loading
|
|
||||||
target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
|
|
||||||
if layer._loaded_numel == target_loaded_numel:
|
|
||||||
self.process_weights_after_loading(layer)
|
|
||||||
|
|
||||||
# Delete the bookkeeping
|
|
||||||
del layer._loaded_numel
|
|
||||||
# Prevent the usual `process_weights_after_loading` call
|
|
||||||
# from doing anything
|
|
||||||
layer._already_called_process_weights_after_loading = True
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
new_extra_weight_attrs["weight_loader"] = patched_weight_loader
|
|
||||||
extra_weight_attrs = new_extra_weight_attrs
|
|
||||||
|
|
||||||
# WEIGHTS
|
# WEIGHTS
|
||||||
w13_weight = torch.nn.Parameter(
|
w13_weight = torch.nn.Parameter(
|
||||||
torch.empty(
|
torch.empty(
|
||||||
@ -875,21 +844,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
if self.block_quant
|
if self.block_quant
|
||||||
else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
|
else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
|
||||||
)
|
)
|
||||||
# If loading fp8 checkpoint, pass the weight loaders.
|
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||||
# If loading an fp16 checkpoint, do not (we will quantize in
|
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||||
# process_weights_after_loading()
|
|
||||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
|
||||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
|
||||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
|
||||||
|
|
||||||
# INPUT_SCALES
|
# INPUT_SCALES
|
||||||
if self.quant_config.activation_scheme == "static":
|
if self.quant_config.activation_scheme == "static":
|
||||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
|
||||||
raise ValueError(
|
|
||||||
"Found static activation scheme for checkpoint that "
|
|
||||||
"was not serialized fp8."
|
|
||||||
)
|
|
||||||
|
|
||||||
w13_input_scale = torch.nn.Parameter(
|
w13_input_scale = torch.nn.Parameter(
|
||||||
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
|
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
|
||||||
)
|
)
|
||||||
@ -986,45 +945,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
layer.w2_weight_scale_inv = Parameter(
|
layer.w2_weight_scale_inv = Parameter(
|
||||||
dg_w2_weight_scale_inv, requires_grad=False
|
dg_w2_weight_scale_inv, requires_grad=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# If checkpoint is fp16, quantize in place.
|
|
||||||
elif not self.quant_config.is_checkpoint_fp8_serialized:
|
|
||||||
fp8_dtype = current_platform.fp8_dtype()
|
|
||||||
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
|
|
||||||
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
|
|
||||||
|
|
||||||
# Re-initialize w13_scale because we directly quantize
|
|
||||||
# merged w13 weights and generate a single scaling factor.
|
|
||||||
replace_parameter(
|
|
||||||
layer,
|
|
||||||
"w13_weight_scale",
|
|
||||||
torch.ones(
|
|
||||||
layer.local_num_experts,
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=w13_weight.device,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
for expert in range(layer.local_num_experts):
|
|
||||||
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
|
|
||||||
ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
|
|
||||||
)
|
|
||||||
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
|
|
||||||
ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
|
|
||||||
)
|
|
||||||
replace_parameter(layer, "w13_weight", w13_weight)
|
|
||||||
replace_parameter(layer, "w2_weight", w2_weight)
|
|
||||||
|
|
||||||
if self.rocm_aiter_moe_enabled:
|
|
||||||
# reshaping weights is required for aiter moe kernel.
|
|
||||||
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
|
||||||
layer.w13_weight, layer.w2_weight
|
|
||||||
)
|
|
||||||
|
|
||||||
replace_parameter(layer, "w13_weight", shuffled_w13)
|
|
||||||
replace_parameter(layer, "w2_weight", shuffled_w2)
|
|
||||||
# If checkpoint is fp8, we need to handle that the
|
|
||||||
# MoE kernels require single activation scale and single weight
|
|
||||||
# scale for w13 per expert.
|
|
||||||
else:
|
else:
|
||||||
# Fp8 moe kernels require a single activation scale.
|
# Fp8 moe kernels require a single activation scale.
|
||||||
# We take the max of all the scales in case they differ.
|
# We take the max of all the scales in case they differ.
|
||||||
@ -1387,6 +1307,151 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class Fp8OnlineMoEMethod(Fp8MoEMethod):
|
||||||
|
"""MoE method for online FP8 quantization.
|
||||||
|
Quantizes FP16/BF16 model checkpoints to FP8 at load time, with dynamic
|
||||||
|
activation scaling. The weight scaling factor will be initialized after
|
||||||
|
the model weights are loaded.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
quant_config: The quantization config.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
|
||||||
|
super().__init__(quant_config, layer)
|
||||||
|
assert not quant_config.is_checkpoint_fp8_serialized
|
||||||
|
assert quant_config.activation_scheme == "dynamic"
|
||||||
|
assert quant_config.weight_block_size is None
|
||||||
|
assert self.flashinfer_moe_backend is None
|
||||||
|
|
||||||
|
def create_weights(
|
||||||
|
self,
|
||||||
|
layer: Module,
|
||||||
|
num_experts: int,
|
||||||
|
hidden_size: int,
|
||||||
|
intermediate_size_per_partition: int,
|
||||||
|
params_dtype: torch.dtype,
|
||||||
|
**extra_weight_attrs,
|
||||||
|
):
|
||||||
|
layer.intermediate_size_per_partition = intermediate_size_per_partition
|
||||||
|
layer.hidden_size = hidden_size
|
||||||
|
layer.num_experts = num_experts
|
||||||
|
layer.orig_dtype = params_dtype
|
||||||
|
layer.weight_block_size = None
|
||||||
|
|
||||||
|
# We are doing online quantization, so patch the weight loader
|
||||||
|
# to call `process_weights_after_loading` in a streaming fashion
|
||||||
|
# as soon as the last weight chunk is loaded.
|
||||||
|
weight_loader = extra_weight_attrs["weight_loader"]
|
||||||
|
# create a new holder to prevent modifying behavior of any other
|
||||||
|
# objects which might depend on the old one
|
||||||
|
new_extra_weight_attrs = extra_weight_attrs
|
||||||
|
|
||||||
|
def patched_weight_loader(param, loaded_weight, *args, **kwargs):
|
||||||
|
# load the current weight chunk
|
||||||
|
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]
|
||||||
|
|
||||||
|
# add a counter to track how many elements we have updated
|
||||||
|
if not hasattr(layer, "_loaded_numel"):
|
||||||
|
layer._loaded_numel = 0
|
||||||
|
layer._loaded_numel += loaded_weight.numel()
|
||||||
|
|
||||||
|
# if we have loaded all of the elements, call
|
||||||
|
# process_weights_after_loading
|
||||||
|
target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
|
||||||
|
if layer._loaded_numel == target_loaded_numel:
|
||||||
|
self.process_weights_after_loading(layer)
|
||||||
|
|
||||||
|
# Delete the bookkeeping
|
||||||
|
del layer._loaded_numel
|
||||||
|
# Prevent the usual `process_weights_after_loading` call
|
||||||
|
# from doing anything
|
||||||
|
layer._already_called_process_weights_after_loading = True
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
new_extra_weight_attrs["weight_loader"] = patched_weight_loader
|
||||||
|
extra_weight_attrs = new_extra_weight_attrs
|
||||||
|
|
||||||
|
# WEIGHTS
|
||||||
|
w13_weight = torch.nn.Parameter(
|
||||||
|
torch.empty(
|
||||||
|
num_experts,
|
||||||
|
2 * intermediate_size_per_partition,
|
||||||
|
hidden_size,
|
||||||
|
dtype=params_dtype,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w13_weight", w13_weight)
|
||||||
|
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||||
|
|
||||||
|
w2_weight = torch.nn.Parameter(
|
||||||
|
torch.empty(
|
||||||
|
num_experts,
|
||||||
|
hidden_size,
|
||||||
|
intermediate_size_per_partition,
|
||||||
|
dtype=params_dtype,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w2_weight", w2_weight)
|
||||||
|
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||||
|
|
||||||
|
# WEIGHT_SCALES
|
||||||
|
# Allocate one scale per expert for w13 (merged w1/w3) and one for w2.
|
||||||
|
# They are overwritten when the weights are quantized after loading.
|
||||||
|
w13_weight_scale = torch.nn.Parameter(
|
||||||
|
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
|
||||||
|
)
|
||||||
|
w2_weight_scale = torch.nn.Parameter(
|
||||||
|
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
|
||||||
|
)
|
||||||
|
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||||
|
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||||
|
|
||||||
|
layer.w13_input_scale = None
|
||||||
|
layer.w2_input_scale = None
|
||||||
|
|
||||||
|
self.rocm_aiter_moe_enabled = False
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer: Module) -> None:
|
||||||
|
if getattr(layer, "_already_called_process_weights_after_loading", False):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Lazy import to avoid importing triton too early.
|
||||||
|
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
|
||||||
|
|
||||||
|
# If checkpoint is fp16, quantize in place.
|
||||||
|
fp8_dtype = current_platform.fp8_dtype()
|
||||||
|
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
|
||||||
|
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
|
||||||
|
|
||||||
|
for expert in range(layer.local_num_experts):
|
||||||
|
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
|
||||||
|
ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
|
||||||
|
)
|
||||||
|
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
|
||||||
|
ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
|
||||||
|
)
|
||||||
|
replace_parameter(layer, "w13_weight", w13_weight)
|
||||||
|
replace_parameter(layer, "w2_weight", w2_weight)
|
||||||
|
|
||||||
|
# Reshuffle weights for AITER if needed.
|
||||||
|
if self.rocm_aiter_moe_enabled:
|
||||||
|
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
||||||
|
layer.w13_weight, layer.w2_weight
|
||||||
|
)
|
||||||
|
replace_parameter(layer, "w13_weight", shuffled_w13)
|
||||||
|
replace_parameter(layer, "w2_weight", shuffled_w2)
|
||||||
|
|
||||||
|
# Reshuffle weights for MARLIN if needed.
|
||||||
|
if self.use_marlin:
|
||||||
|
prepare_moe_fp8_layer_for_marlin(
|
||||||
|
layer, False, input_dtype=self.marlin_input_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Fp8KVCacheMethod(BaseKVCacheMethod):
|
class Fp8KVCacheMethod(BaseKVCacheMethod):
|
||||||
"""
|
"""
|
||||||
Supports loading kv-cache scaling factors from FP8 checkpoints.
|
Supports loading kv-cache scaling factors from FP8 checkpoints.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user