mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-22 22:17:55 +08:00
[Perf] Remove duplicated NVFP4 blockscales to save memory (#23379)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
6ace2f72b0
commit
d52358c1e0
@ -246,13 +246,13 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# swizzle weight scales
|
# swizzle weight scales
|
||||||
layer.w13_blockscale_swizzled = torch.nn.Parameter(swizzle_blockscale(
|
layer.w13_weight_scale = torch.nn.Parameter(swizzle_blockscale(
|
||||||
layer.w13_weight_scale),
|
layer.w13_weight_scale),
|
||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
|
|
||||||
layer.w2_blockscale_swizzled = torch.nn.Parameter(swizzle_blockscale(
|
layer.w2_weight_scale = torch.nn.Parameter(swizzle_blockscale(
|
||||||
layer.w2_weight_scale),
|
layer.w2_weight_scale),
|
||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
|
|
||||||
# w13
|
# w13
|
||||||
w13_input_global_scale = layer.w13_input_global_scale.max(
|
w13_input_global_scale = layer.w13_input_global_scale.max(
|
||||||
@ -383,8 +383,8 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
|||||||
activation=activation,
|
activation=activation,
|
||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
w1_scale=layer.w13_blockscale_swizzled,
|
w1_scale=layer.w13_weight_scale,
|
||||||
w2_scale=layer.w2_blockscale_swizzled,
|
w2_scale=layer.w2_weight_scale,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -406,8 +406,8 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
|||||||
activation=activation,
|
activation=activation,
|
||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
w1_scale=layer.w13_blockscale_swizzled,
|
w1_scale=layer.w13_weight_scale,
|
||||||
w2_scale=layer.w2_blockscale_swizzled,
|
w2_scale=layer.w2_weight_scale,
|
||||||
g1_alphas=layer.g1_alphas,
|
g1_alphas=layer.g1_alphas,
|
||||||
g2_alphas=layer.g2_alphas,
|
g2_alphas=layer.g2_alphas,
|
||||||
a1_gscale=layer.w13_input_scale_quant,
|
a1_gscale=layer.w13_input_scale_quant,
|
||||||
@ -427,8 +427,8 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
|||||||
a=x,
|
a=x,
|
||||||
w1_fp4=layer.w13_weight,
|
w1_fp4=layer.w13_weight,
|
||||||
w2_fp4=layer.w2_weight,
|
w2_fp4=layer.w2_weight,
|
||||||
w1_blockscale=layer.w13_blockscale_swizzled,
|
w1_blockscale=layer.w13_weight_scale,
|
||||||
w2_blockscale=layer.w2_blockscale_swizzled,
|
w2_blockscale=layer.w2_weight_scale,
|
||||||
g1_alphas=layer.g1_alphas,
|
g1_alphas=layer.g1_alphas,
|
||||||
g2_alphas=layer.g2_alphas,
|
g2_alphas=layer.g2_alphas,
|
||||||
a1_gscale=layer.w13_input_scale_quant,
|
a1_gscale=layer.w13_input_scale_quant,
|
||||||
|
|||||||
@ -112,13 +112,12 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
|||||||
torch.uint8), epilogue_tile_m).reshape(
|
torch.uint8), epilogue_tile_m).reshape(
|
||||||
weight_scale.shape).view(torch.float8_e4m3fn))
|
weight_scale.shape).view(torch.float8_e4m3fn))
|
||||||
|
|
||||||
layer.weight_scale_swizzled = Parameter(weight_scale,
|
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||||
requires_grad=False)
|
|
||||||
layer.weight_packed = Parameter(weight, requires_grad=False)
|
layer.weight_packed = Parameter(weight, requires_grad=False)
|
||||||
else:
|
else:
|
||||||
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
|
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
|
||||||
layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
|
layer.weight_scale = Parameter(swizzled_weight_scale,
|
||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
layer.weight_packed = Parameter(layer.weight_packed.data,
|
layer.weight_packed = Parameter(layer.weight_packed.data,
|
||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
|
|
||||||
@ -136,7 +135,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
|||||||
x=x,
|
x=x,
|
||||||
input_global_scale=layer.input_global_scale,
|
input_global_scale=layer.input_global_scale,
|
||||||
weight=layer.weight_packed,
|
weight=layer.weight_packed,
|
||||||
weight_scale_swizzled=layer.weight_scale_swizzled,
|
weight_scale_swizzled=layer.weight_scale,
|
||||||
weight_global_scale=layer.weight_global_scale)
|
weight_global_scale=layer.weight_global_scale)
|
||||||
if bias is not None:
|
if bias is not None:
|
||||||
out = out + bias
|
out = out + bias
|
||||||
@ -149,7 +148,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
|||||||
x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale)
|
x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale)
|
||||||
|
|
||||||
mm_args = (x_fp4, layer.weight_packed, x_blockscale,
|
mm_args = (x_fp4, layer.weight_packed, x_blockscale,
|
||||||
layer.weight_scale_swizzled, layer.alpha, output_dtype)
|
layer.weight_scale, layer.alpha, output_dtype)
|
||||||
if self.backend == "flashinfer-trtllm":
|
if self.backend == "flashinfer-trtllm":
|
||||||
out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
|
out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
|
||||||
elif self.backend == "flashinfer-cutlass":
|
elif self.backend == "flashinfer-cutlass":
|
||||||
|
|||||||
@ -907,20 +907,18 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
|||||||
torch.uint8), epilogue_tile_m).reshape(
|
torch.uint8), epilogue_tile_m).reshape(
|
||||||
weight_scale.shape).view(torch.float8_e4m3fn))
|
weight_scale.shape).view(torch.float8_e4m3fn))
|
||||||
|
|
||||||
layer.weight_scale_swizzled = Parameter(weight_scale,
|
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||||
requires_grad=False)
|
|
||||||
layer.weight = Parameter(weight, requires_grad=False)
|
layer.weight = Parameter(weight, requires_grad=False)
|
||||||
else:
|
else:
|
||||||
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
|
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
|
||||||
layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
|
layer.weight_scale = Parameter(swizzled_weight_scale,
|
||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
layer.weight = Parameter(layer.weight.data, requires_grad=False)
|
layer.weight = Parameter(layer.weight.data, requires_grad=False)
|
||||||
|
|
||||||
if self.backend == "marlin":
|
if self.backend == "marlin":
|
||||||
prepare_fp4_layer_for_marlin(layer)
|
prepare_fp4_layer_for_marlin(layer)
|
||||||
del layer.alpha
|
del layer.alpha
|
||||||
del layer.input_scale
|
del layer.input_scale
|
||||||
del layer.weight_scale_swizzled
|
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
@ -951,14 +949,14 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
|||||||
assert (x_fp4.dtype == torch.uint8)
|
assert (x_fp4.dtype == torch.uint8)
|
||||||
assert (layer.weight.dtype == torch.uint8)
|
assert (layer.weight.dtype == torch.uint8)
|
||||||
assert (x_blockscale.dtype == torch.float8_e4m3fn)
|
assert (x_blockscale.dtype == torch.float8_e4m3fn)
|
||||||
assert (layer.weight_scale_swizzled.dtype == torch.float8_e4m3fn)
|
assert (layer.weight_scale.dtype == torch.float8_e4m3fn)
|
||||||
assert (layer.alpha.dtype == torch.float32)
|
assert (layer.alpha.dtype == torch.float32)
|
||||||
|
|
||||||
mm_args = (
|
mm_args = (
|
||||||
x_fp4,
|
x_fp4,
|
||||||
layer.weight,
|
layer.weight,
|
||||||
x_blockscale,
|
x_blockscale,
|
||||||
layer.weight_scale_swizzled,
|
layer.weight_scale,
|
||||||
layer.alpha,
|
layer.alpha,
|
||||||
output_dtype,
|
output_dtype,
|
||||||
)
|
)
|
||||||
@ -1320,16 +1318,16 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
"Weight Blockscale must be represented as FP8-E4M3")
|
"Weight Blockscale must be represented as FP8-E4M3")
|
||||||
w13_blockscale_swizzled = swizzle_blockscale(
|
w13_blockscale_swizzled = swizzle_blockscale(
|
||||||
layer.w13_weight_scale)
|
layer.w13_weight_scale)
|
||||||
layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled,
|
layer.w13_weight_scale = Parameter(w13_blockscale_swizzled,
|
||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
|
|
||||||
assert (layer.w2_weight_scale.shape[2] % 16 == 0), (
|
assert (layer.w2_weight_scale.shape[2] % 16 == 0), (
|
||||||
"Expected weight_scale.dim(1) to be divisible by 16")
|
"Expected weight_scale.dim(1) to be divisible by 16")
|
||||||
assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), (
|
assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), (
|
||||||
"Weight Blockscale must be represented as FP8-E4M3")
|
"Weight Blockscale must be represented as FP8-E4M3")
|
||||||
w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
|
w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
|
||||||
layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
|
layer.w2_weight_scale = Parameter(w2_blockscale_swizzled,
|
||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
layer.w2_weight = Parameter(layer.w2_weight.data,
|
layer.w2_weight = Parameter(layer.w2_weight.data,
|
||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
|
|
||||||
@ -1339,8 +1337,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
del layer.g2_alphas
|
del layer.g2_alphas
|
||||||
del layer.w13_input_scale_quant
|
del layer.w13_input_scale_quant
|
||||||
del layer.w2_input_scale_quant
|
del layer.w2_input_scale_quant
|
||||||
del layer.w13_blockscale_swizzled
|
|
||||||
del layer.w2_blockscale_swizzled
|
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
@ -1474,8 +1470,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
activation=activation,
|
activation=activation,
|
||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
w1_scale=layer.w13_blockscale_swizzled,
|
w1_scale=layer.w13_weight_scale,
|
||||||
w2_scale=layer.w2_blockscale_swizzled,
|
w2_scale=layer.w2_weight_scale,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
)
|
)
|
||||||
elif (self.allow_flashinfer
|
elif (self.allow_flashinfer
|
||||||
@ -1489,8 +1485,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
w2=layer.w2_weight,
|
w2=layer.w2_weight,
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
w1_scale=layer.w13_blockscale_swizzled,
|
w1_scale=layer.w13_weight_scale,
|
||||||
w2_scale=layer.w2_blockscale_swizzled,
|
w2_scale=layer.w2_weight_scale,
|
||||||
g1_alphas=layer.g1_alphas,
|
g1_alphas=layer.g1_alphas,
|
||||||
g2_alphas=layer.g2_alphas,
|
g2_alphas=layer.g2_alphas,
|
||||||
a1_gscale=layer.w13_input_scale_quant,
|
a1_gscale=layer.w13_input_scale_quant,
|
||||||
@ -1510,8 +1506,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
a=x,
|
a=x,
|
||||||
w1_fp4=layer.w13_weight,
|
w1_fp4=layer.w13_weight,
|
||||||
w2_fp4=layer.w2_weight,
|
w2_fp4=layer.w2_weight,
|
||||||
w1_blockscale=layer.w13_blockscale_swizzled,
|
w1_blockscale=layer.w13_weight_scale,
|
||||||
w2_blockscale=layer.w2_blockscale_swizzled,
|
w2_blockscale=layer.w2_weight_scale,
|
||||||
g1_alphas=layer.g1_alphas,
|
g1_alphas=layer.g1_alphas,
|
||||||
g2_alphas=layer.g2_alphas,
|
g2_alphas=layer.g2_alphas,
|
||||||
a1_gscale=layer.w13_input_scale_quant,
|
a1_gscale=layer.w13_input_scale_quant,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user