[Perf] Remove duplicated NVFP4 blockscales to save memory (#23379)

Signed-off-by: mgoin <mgoin64@gmail.com>
Michael Goin, 2025-08-26 07:16:33 -04:00 (committed by GitHub)
commit d52358c1e0, parent 6ace2f72b0
3 changed files with 30 additions and 35 deletions
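
The change is one pattern applied across the NVFP4 quantization paths: instead of keeping the original block scale in weight_scale and storing a second, swizzled copy under a new *_blockscale_swizzled parameter, the swizzled tensor is written back into the existing weight_scale parameter, so only one copy of each block scale stays allocated after weight processing. A minimal sketch of the pattern, assuming a stand-in for vLLM's swizzle_blockscale helper (the real helper reorders scales into the layout the FP4 GEMM kernels expect):

import torch
from torch.nn import Parameter

def swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor:
    # Stand-in for vLLM's swizzle_blockscale: the real helper reorders the
    # block scales into the swizzled layout expected by the FP4 GEMM kernels.
    return scale.clone()

layer = torch.nn.Module()
layer.weight_scale = Parameter(torch.ones(128, 8), requires_grad=False)

# Before this commit: both the plain and the swizzled scale stayed allocated.
#   layer.weight_scale_swizzled = Parameter(
#       swizzle_blockscale(layer.weight_scale), requires_grad=False)

# After this commit: the swizzled scale replaces the original in place, so the
# plain copy can be freed once nothing else references it.
layer.weight_scale = Parameter(swizzle_blockscale(layer.weight_scale),
                               requires_grad=False)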


@@ -246,13 +246,13 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
             return
         # swizzle weight scales
-        layer.w13_blockscale_swizzled = torch.nn.Parameter(swizzle_blockscale(
+        layer.w13_weight_scale = torch.nn.Parameter(swizzle_blockscale(
             layer.w13_weight_scale),
-                                                           requires_grad=False)
+                                                    requires_grad=False)
-        layer.w2_blockscale_swizzled = torch.nn.Parameter(swizzle_blockscale(
+        layer.w2_weight_scale = torch.nn.Parameter(swizzle_blockscale(
             layer.w2_weight_scale),
-                                                          requires_grad=False)
+                                                   requires_grad=False)
         # w13
         w13_input_global_scale = layer.w13_input_global_scale.max(
@@ -383,8 +383,8 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
                 activation=activation,
                 global_num_experts=global_num_experts,
                 expert_map=expert_map,
-                w1_scale=layer.w13_blockscale_swizzled,
-                w2_scale=layer.w2_blockscale_swizzled,
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
                 apply_router_weight_on_input=apply_router_weight_on_input,
             )
@@ -406,8 +406,8 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
                 activation=activation,
                 global_num_experts=global_num_experts,
                 expert_map=expert_map,
-                w1_scale=layer.w13_blockscale_swizzled,
-                w2_scale=layer.w2_blockscale_swizzled,
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
                 g1_alphas=layer.g1_alphas,
                 g2_alphas=layer.g2_alphas,
                 a1_gscale=layer.w13_input_scale_quant,
@@ -427,8 +427,8 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
                 a=x,
                 w1_fp4=layer.w13_weight,
                 w2_fp4=layer.w2_weight,
-                w1_blockscale=layer.w13_blockscale_swizzled,
-                w2_blockscale=layer.w2_blockscale_swizzled,
+                w1_blockscale=layer.w13_weight_scale,
+                w2_blockscale=layer.w2_weight_scale,
                 g1_alphas=layer.g1_alphas,
                 g2_alphas=layer.g2_alphas,
                 a1_gscale=layer.w13_input_scale_quant,
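
To put a rough number on the saving (the expert shapes below are hypothetical; the fixed facts are that NVFP4 stores one FP8-E4M3 scale per group of 16 FP4 weights, so each block-scale tensor costs one byte per 16 weight elements, and before this commit that tensor existed twice per weight):

# Back-of-the-envelope estimate of the memory freed by dropping the duplicate
# block scales. NVFP4 keeps one FP8-E4M3 scale per group of 16 FP4 weights,
# i.e. an [n, k] weight carries n * k / 16 bytes of block scales.
GROUP_SIZE = 16

def blockscale_bytes(n: int, k: int) -> int:
    return n * (k // GROUP_SIZE)  # one byte per FP8-E4M3 scale

# Hypothetical MoE shapes: w13 packs the gate+up projections, w2 is the down
# projection; both had a duplicate swizzled copy of their block scales.
num_experts, hidden, intermediate = 8, 4096, 14336
duplicate_bytes = num_experts * (blockscale_bytes(2 * intermediate, hidden) +
                                 blockscale_bytes(hidden, intermediate))
print(f"duplicate blockscales: {duplicate_bytes / 2**20:.1f} MiB per MoE layer")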


@@ -112,13 +112,12 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
                     torch.uint8), epilogue_tile_m).reshape(
                         weight_scale.shape).view(torch.float8_e4m3fn))
-            layer.weight_scale_swizzled = Parameter(weight_scale,
-                                                    requires_grad=False)
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
             layer.weight_packed = Parameter(weight, requires_grad=False)
         else:
             swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
-            layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
-                                                    requires_grad=False)
+            layer.weight_scale = Parameter(swizzled_weight_scale,
+                                           requires_grad=False)
             layer.weight_packed = Parameter(layer.weight_packed.data,
                                             requires_grad=False)
@@ -136,7 +135,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
                 x=x,
                 input_global_scale=layer.input_global_scale,
                 weight=layer.weight_packed,
-                weight_scale_swizzled=layer.weight_scale_swizzled,
+                weight_scale_swizzled=layer.weight_scale,
                 weight_global_scale=layer.weight_global_scale)
             if bias is not None:
                 out = out + bias
@@ -149,7 +148,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
         x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale)
         mm_args = (x_fp4, layer.weight_packed, x_blockscale,
-                   layer.weight_scale_swizzled, layer.alpha, output_dtype)
+                   layer.weight_scale, layer.alpha, output_dtype)
         if self.backend == "flashinfer-trtllm":
             out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
         elif self.backend == "flashinfer-cutlass":


@@ -907,20 +907,18 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
                     torch.uint8), epilogue_tile_m).reshape(
                         weight_scale.shape).view(torch.float8_e4m3fn))
-            layer.weight_scale_swizzled = Parameter(weight_scale,
-                                                    requires_grad=False)
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
             layer.weight = Parameter(weight, requires_grad=False)
         else:
             swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
-            layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
-                                                    requires_grad=False)
+            layer.weight_scale = Parameter(swizzled_weight_scale,
+                                           requires_grad=False)
             layer.weight = Parameter(layer.weight.data, requires_grad=False)
         if self.backend == "marlin":
             prepare_fp4_layer_for_marlin(layer)
             del layer.alpha
             del layer.input_scale
-            del layer.weight_scale_swizzled

     def apply(
         self,
@@ -951,14 +949,14 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
         assert (x_fp4.dtype == torch.uint8)
         assert (layer.weight.dtype == torch.uint8)
         assert (x_blockscale.dtype == torch.float8_e4m3fn)
-        assert (layer.weight_scale_swizzled.dtype == torch.float8_e4m3fn)
+        assert (layer.weight_scale.dtype == torch.float8_e4m3fn)
         assert (layer.alpha.dtype == torch.float32)
         mm_args = (
             x_fp4,
             layer.weight,
             x_blockscale,
-            layer.weight_scale_swizzled,
+            layer.weight_scale,
             layer.alpha,
             output_dtype,
         )
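
The asserts above spell out the dtype contract for the FP4 GEMM inputs. Collected as a standalone helper (hypothetical, not part of this commit; the real code keeps the inline asserts), it reads:

import torch

def check_fp4_mm_inputs(x_fp4, weight, x_blockscale, weight_scale, alpha):
    # Dtype contract taken from the asserts in ModelOptNvFp4LinearMethod.apply:
    # packed FP4 tensors travel as uint8, block scales as FP8-E4M3, and the
    # per-tensor alpha as float32. weight_scale is the swizzled block scale,
    # which after this commit lives directly under layer.weight_scale.
    assert x_fp4.dtype == torch.uint8
    assert weight.dtype == torch.uint8
    assert x_blockscale.dtype == torch.float8_e4m3fn
    assert weight_scale.dtype == torch.float8_e4m3fn
    assert alpha.dtype == torch.float32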
@@ -1320,16 +1318,16 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                 "Weight Blockscale must be represented as FP8-E4M3")
             w13_blockscale_swizzled = swizzle_blockscale(
                 layer.w13_weight_scale)
-            layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled,
-                                                      requires_grad=False)
+            layer.w13_weight_scale = Parameter(w13_blockscale_swizzled,
+                                               requires_grad=False)
             assert (layer.w2_weight_scale.shape[2] % 16 == 0), (
                 "Expected weight_scale.dim(1) to be divisible by 16")
             assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), (
                 "Weight Blockscale must be represented as FP8-E4M3")
             w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
-            layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
-                                                     requires_grad=False)
+            layer.w2_weight_scale = Parameter(w2_blockscale_swizzled,
+                                              requires_grad=False)
             layer.w2_weight = Parameter(layer.w2_weight.data,
                                         requires_grad=False)
@@ -1339,8 +1337,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             del layer.g2_alphas
             del layer.w13_input_scale_quant
             del layer.w2_input_scale_quant
-            del layer.w13_blockscale_swizzled
-            del layer.w2_blockscale_swizzled

     def apply(
         self,
@@ -1474,8 +1470,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                 activation=activation,
                 global_num_experts=global_num_experts,
                 expert_map=expert_map,
-                w1_scale=layer.w13_blockscale_swizzled,
-                w2_scale=layer.w2_blockscale_swizzled,
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
                 apply_router_weight_on_input=apply_router_weight_on_input,
             )
         elif (self.allow_flashinfer
@@ -1489,8 +1485,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                 w2=layer.w2_weight,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
-                w1_scale=layer.w13_blockscale_swizzled,
-                w2_scale=layer.w2_blockscale_swizzled,
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
                 g1_alphas=layer.g1_alphas,
                 g2_alphas=layer.g2_alphas,
                 a1_gscale=layer.w13_input_scale_quant,
@@ -1510,8 +1506,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                 a=x,
                 w1_fp4=layer.w13_weight,
                 w2_fp4=layer.w2_weight,
-                w1_blockscale=layer.w13_blockscale_swizzled,
-                w2_blockscale=layer.w2_blockscale_swizzled,
+                w1_blockscale=layer.w13_weight_scale,
+                w2_blockscale=layer.w2_weight_scale,
                 g1_alphas=layer.g1_alphas,
                 g2_alphas=layer.g2_alphas,
                 a1_gscale=layer.w13_input_scale_quant,
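
A hypothetical post-load sanity check (not part of this commit; attribute names taken from the diff) that no stale duplicate block scales survive weight processing:

import torch

def assert_no_duplicate_blockscales(model: torch.nn.Module) -> None:
    # With this commit the swizzled scales live under the original
    # *_weight_scale parameter names, so none of the old duplicate attributes
    # should remain allocated on any quantized submodule.
    stale = ("weight_scale_swizzled", "w13_blockscale_swizzled",
             "w2_blockscale_swizzled")
    for name, module in model.named_modules():
        for attr in stale:
            assert not hasattr(module, attr), f"{name}.{attr} still allocated"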