Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[BugFix] Fix Torch.Compile For DeepSeek (#12594)
Co-authored-by: simon-mo <xmo@berkeley.edu>
This commit is contained in:
parent e3f7ff65e7
commit 325f679f32
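For context on the fix itself: torch.compile traces plain torch.nn.Parameter tensors far more reliably than custom Parameter subclasses, which is why the diff re-registers the loaded weights as plain Parameters before compilation. A minimal, self-contained sketch of that pattern (illustrative only, not vLLM code; ScaledWeightParameter, TinyLinear and finalize_weights are made-up names):

    import torch
    from torch import nn


    class ScaledWeightParameter(nn.Parameter):
        """Stand-in for a metadata-carrying Parameter subclass used while loading."""


    class TinyLinear(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # Right after weight loading, the tensor lives in a Parameter subclass.
            self.weight = ScaledWeightParameter(torch.randn(8, 8),
                                                requires_grad=False)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x @ self.weight.t()


    def finalize_weights(layer: nn.Module) -> None:
        # The pattern applied by this commit: keep the tensor data, drop the
        # subclass wrapper, so the compiled graph only sees plain Parameters.
        layer.weight = nn.Parameter(layer.weight.data, requires_grad=False)


    layer = TinyLinear()
    finalize_weights(layer)
    compiled = torch.compile(layer, backend="eager")  # "eager" avoids needing a C++ toolchain
    print(compiled(torch.randn(2, 8)).shape)  # torch.Size([2, 8])

vLLM's real weight-loading parameters carry extra metadata; the sketch only demonstrates the unwrap-and-rewrap step that the hunks below perform for the dense and MoE fp8 weights.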
@@ -245,20 +245,24 @@ class Fp8LinearMethod(LinearMethodBase):
         layer.register_parameter("input_scale", None)
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        # Block quant doesn't need to process weights after loading
+        # TODO(rob): refactor block quant into separate class.
         if self.block_quant:
+            assert self.quant_config.activation_scheme == "dynamic"
             if current_platform.is_rocm():
-                weight, weight_scale, _ = \
+                weight, weight_scale_inv, _ = \
                     normalize_e4m3fn_to_e4m3fnuz(
                         weight=layer.weight,
-                        weight_scale=layer.weight_scale_inv,
-                        input_scale=layer.input_scale)
-                layer.weight = Parameter(weight, requires_grad=False)
-                layer.weight_scale_inv = Parameter(weight_scale,
-                                                   requires_grad=False)
+                        weight_scale=layer.weight_scale_inv)
+            else:
+                weight = layer.weight.data
+                weight_scale_inv = layer.weight_scale_inv.data
+
+            # Torch.compile cannot use Parameter subclasses.
+            layer.weight = Parameter(weight, requires_grad=False)
+            layer.weight_scale_inv = Parameter(weight_scale_inv,
+                                               requires_grad=False)
             return
         layer.weight = torch.nn.Parameter(layer.weight.data,
                                           requires_grad=False)
 
         # If checkpoint not serialized fp8, quantize the weights.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
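The ROCm branch above converts OCP float8_e4m3fn weights to the float8_e4m3fnuz format used by AMD hardware before re-wrapping them. A simplified sketch of what such a conversion involves (an approximation for illustration, not vLLM's normalize_e4m3fn_to_e4m3fnuz verbatim; assumes a PyTorch build with the float8 dtypes): the 0x80 bit pattern is -0.0 in e4m3fn but NaN in e4m3fnuz, so it is cleared, and the scale is doubled because identical bits decode to half the value under e4m3fnuz's larger exponent bias.

    import torch


    def e4m3fn_to_e4m3fnuz(weight: torch.Tensor, weight_scale: torch.Tensor):
        # Approximate sketch only; the real reference is vLLM's
        # normalize_e4m3fn_to_e4m3fnuz.
        assert weight.dtype == torch.float8_e4m3fn
        bits = weight.view(torch.int8).clone()
        # 0x80 encodes -0.0 in e4m3fn but NaN in e4m3fnuz; clear it to +0.0.
        bits[bits == -128] = 0
        fnuz_weight = bits.view(torch.float8_e4m3fnuz)
        # Same bits, half the value under the fnuz exponent bias, so double
        # the scale to keep weight * scale unchanged after dequantization.
        return fnuz_weight, weight_scale * 2.0


    w = torch.randn(4, 4).to(torch.float8_e4m3fn)
    s = torch.tensor(0.5)
    w_fnuz, s_fnuz = e4m3fn_to_e4m3fnuz(w, s)
    print(w_fnuz.dtype, s_fnuz.item())  # torch.float8_e4m3fnuz 1.0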
@@ -507,8 +511,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             layer.w2_input_scale = None
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        # Block quant doesn't need to process weights after loading
+        # TODO (rob): refactor block quant into separate class.
         if self.block_quant:
+            assert self.quant_config.activation_scheme == "dynamic"
             if current_platform.is_rocm():
                 w13_weight, w13_weight_scale_inv, w13_input_scale = \
                     normalize_e4m3fn_to_e4m3fnuz(
@@ -518,22 +523,21 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                     normalize_e4m3fn_to_e4m3fnuz(
                         layer.w2_weight, layer.w2_weight_scale_inv,
                         layer.w2_input_scale)
-                # Reset the parameter
-                layer.w13_weight = torch.nn.Parameter(w13_weight,
-                                                      requires_grad=False)
-                layer.w13_weight_scale_inv = torch.nn.Parameter(
-                    w13_weight_scale_inv, requires_grad=False)
-                if w13_input_scale is not None:
-                    layer.w13_input_scale = torch.nn.Parameter(
-                        w13_input_scale, requires_grad=False)
-                layer.w2_weight = torch.nn.Parameter(w2_weight,
-                                                     requires_grad=False)
-                layer.w2_weight_scale_inv = torch.nn.Parameter(
-                    w2_weight_scale_inv, requires_grad=False)
-                if w2_input_scale is not None:
-                    layer.w2_input_scale = torch.nn.Parameter(
-                        w2_input_scale, requires_grad=False)
+            else:
+                w13_weight = layer.w13_weight.data
+                w13_weight_scale_inv = layer.w13_weight_scale_inv.data
+                w2_weight = layer.w2_weight
+                w2_weight_scale_inv = layer.w2_weight_scale_inv
+
+            # torch.compile() cannot use Parameter subclasses.
+            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
+            layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv,
+                                                   requires_grad=False)
+            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
+            layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
+                                                  requires_grad=False)
             return
 
         # If checkpoint is fp16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             # If rocm, use float8_e4m3fnuz as dtype
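The MoE hunks above apply the same restructuring to the fused expert tensors (w13_weight, w2_weight and their scales): the e4m3fnuz normalization stays ROCm-only, while the plain-Parameter re-registration now runs unconditionally, so both the CUDA and ROCm paths hand torch.compile ordinary Parameters. A cheap post-loading sanity check one could run before compiling (illustrative snippet, not part of the commit):

    from torch import nn


    def assert_plain_parameters(module: nn.Module) -> None:
        # Flag any weight still held in a Parameter subclass after
        # process_weights_after_loading has run, e.g. before torch.compile(model).
        for name, param in module.named_parameters():
            assert type(param) is nn.Parameter, (
                f"{name} is {type(param).__name__}, expected plain nn.Parameter")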