mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-27 06:48:42 +08:00
[Perf] Remove duplicated NVFP4 blockscales to save memory (#23379)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
6ace2f72b0
commit
d52358c1e0
@ -246,13 +246,13 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||
return
|
||||
|
||||
# swizzle weight scales
|
||||
layer.w13_blockscale_swizzled = torch.nn.Parameter(swizzle_blockscale(
|
||||
layer.w13_weight_scale = torch.nn.Parameter(swizzle_blockscale(
|
||||
layer.w13_weight_scale),
|
||||
requires_grad=False)
|
||||
requires_grad=False)
|
||||
|
||||
layer.w2_blockscale_swizzled = torch.nn.Parameter(swizzle_blockscale(
|
||||
layer.w2_weight_scale = torch.nn.Parameter(swizzle_blockscale(
|
||||
layer.w2_weight_scale),
|
||||
requires_grad=False)
|
||||
requires_grad=False)
|
||||
|
||||
# w13
|
||||
w13_input_global_scale = layer.w13_input_global_scale.max(
|
||||
@ -383,8 +383,8 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=layer.w13_blockscale_swizzled,
|
||||
w2_scale=layer.w2_blockscale_swizzled,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
@ -406,8 +406,8 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=layer.w13_blockscale_swizzled,
|
||||
w2_scale=layer.w2_blockscale_swizzled,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
g1_alphas=layer.g1_alphas,
|
||||
g2_alphas=layer.g2_alphas,
|
||||
a1_gscale=layer.w13_input_scale_quant,
|
||||
@ -427,8 +427,8 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||
a=x,
|
||||
w1_fp4=layer.w13_weight,
|
||||
w2_fp4=layer.w2_weight,
|
||||
w1_blockscale=layer.w13_blockscale_swizzled,
|
||||
w2_blockscale=layer.w2_blockscale_swizzled,
|
||||
w1_blockscale=layer.w13_weight_scale,
|
||||
w2_blockscale=layer.w2_weight_scale,
|
||||
g1_alphas=layer.g1_alphas,
|
||||
g2_alphas=layer.g2_alphas,
|
||||
a1_gscale=layer.w13_input_scale_quant,
|
||||
|
||||
@ -112,13 +112,12 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
||||
torch.uint8), epilogue_tile_m).reshape(
|
||||
weight_scale.shape).view(torch.float8_e4m3fn))
|
||||
|
||||
layer.weight_scale_swizzled = Parameter(weight_scale,
|
||||
requires_grad=False)
|
||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||
layer.weight_packed = Parameter(weight, requires_grad=False)
|
||||
else:
|
||||
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
|
||||
layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
|
||||
requires_grad=False)
|
||||
layer.weight_scale = Parameter(swizzled_weight_scale,
|
||||
requires_grad=False)
|
||||
layer.weight_packed = Parameter(layer.weight_packed.data,
|
||||
requires_grad=False)
|
||||
|
||||
@ -136,7 +135,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
||||
x=x,
|
||||
input_global_scale=layer.input_global_scale,
|
||||
weight=layer.weight_packed,
|
||||
weight_scale_swizzled=layer.weight_scale_swizzled,
|
||||
weight_scale_swizzled=layer.weight_scale,
|
||||
weight_global_scale=layer.weight_global_scale)
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
@ -149,7 +148,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
||||
x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale)
|
||||
|
||||
mm_args = (x_fp4, layer.weight_packed, x_blockscale,
|
||||
layer.weight_scale_swizzled, layer.alpha, output_dtype)
|
||||
layer.weight_scale, layer.alpha, output_dtype)
|
||||
if self.backend == "flashinfer-trtllm":
|
||||
out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
|
||||
elif self.backend == "flashinfer-cutlass":
|
||||
|
||||
@ -907,20 +907,18 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
||||
torch.uint8), epilogue_tile_m).reshape(
|
||||
weight_scale.shape).view(torch.float8_e4m3fn))
|
||||
|
||||
layer.weight_scale_swizzled = Parameter(weight_scale,
|
||||
requires_grad=False)
|
||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||
layer.weight = Parameter(weight, requires_grad=False)
|
||||
else:
|
||||
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
|
||||
layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
|
||||
requires_grad=False)
|
||||
layer.weight_scale = Parameter(swizzled_weight_scale,
|
||||
requires_grad=False)
|
||||
layer.weight = Parameter(layer.weight.data, requires_grad=False)
|
||||
|
||||
if self.backend == "marlin":
|
||||
prepare_fp4_layer_for_marlin(layer)
|
||||
del layer.alpha
|
||||
del layer.input_scale
|
||||
del layer.weight_scale_swizzled
|
||||
|
||||
def apply(
|
||||
self,
|
||||
@ -951,14 +949,14 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
||||
assert (x_fp4.dtype == torch.uint8)
|
||||
assert (layer.weight.dtype == torch.uint8)
|
||||
assert (x_blockscale.dtype == torch.float8_e4m3fn)
|
||||
assert (layer.weight_scale_swizzled.dtype == torch.float8_e4m3fn)
|
||||
assert (layer.weight_scale.dtype == torch.float8_e4m3fn)
|
||||
assert (layer.alpha.dtype == torch.float32)
|
||||
|
||||
mm_args = (
|
||||
x_fp4,
|
||||
layer.weight,
|
||||
x_blockscale,
|
||||
layer.weight_scale_swizzled,
|
||||
layer.weight_scale,
|
||||
layer.alpha,
|
||||
output_dtype,
|
||||
)
|
||||
@ -1320,16 +1318,16 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
"Weight Blockscale must be represented as FP8-E4M3")
|
||||
w13_blockscale_swizzled = swizzle_blockscale(
|
||||
layer.w13_weight_scale)
|
||||
layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled,
|
||||
requires_grad=False)
|
||||
layer.w13_weight_scale = Parameter(w13_blockscale_swizzled,
|
||||
requires_grad=False)
|
||||
|
||||
assert (layer.w2_weight_scale.shape[2] % 16 == 0), (
|
||||
"Expected weight_scale.dim(1) to be divisible by 16")
|
||||
assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), (
|
||||
"Weight Blockscale must be represented as FP8-E4M3")
|
||||
w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
|
||||
layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
|
||||
requires_grad=False)
|
||||
layer.w2_weight_scale = Parameter(w2_blockscale_swizzled,
|
||||
requires_grad=False)
|
||||
layer.w2_weight = Parameter(layer.w2_weight.data,
|
||||
requires_grad=False)
|
||||
|
||||
@ -1339,8 +1337,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
del layer.g2_alphas
|
||||
del layer.w13_input_scale_quant
|
||||
del layer.w2_input_scale_quant
|
||||
del layer.w13_blockscale_swizzled
|
||||
del layer.w2_blockscale_swizzled
|
||||
|
||||
def apply(
|
||||
self,
|
||||
@ -1474,8 +1470,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=layer.w13_blockscale_swizzled,
|
||||
w2_scale=layer.w2_blockscale_swizzled,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
elif (self.allow_flashinfer
|
||||
@ -1489,8 +1485,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
w1_scale=layer.w13_blockscale_swizzled,
|
||||
w2_scale=layer.w2_blockscale_swizzled,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
g1_alphas=layer.g1_alphas,
|
||||
g2_alphas=layer.g2_alphas,
|
||||
a1_gscale=layer.w13_input_scale_quant,
|
||||
@ -1510,8 +1506,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
a=x,
|
||||
w1_fp4=layer.w13_weight,
|
||||
w2_fp4=layer.w2_weight,
|
||||
w1_blockscale=layer.w13_blockscale_swizzled,
|
||||
w2_blockscale=layer.w2_blockscale_swizzled,
|
||||
w1_blockscale=layer.w13_weight_scale,
|
||||
w2_blockscale=layer.w2_weight_scale,
|
||||
g1_alphas=layer.g1_alphas,
|
||||
g2_alphas=layer.g2_alphas,
|
||||
a1_gscale=layer.w13_input_scale_quant,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user