From d52358c1e07768266e3db92e847cd28af87ca4b9 Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Tue, 26 Aug 2025 07:16:33 -0400
Subject: [PATCH] [Perf] Remove duplicated NVFP4 blockscales to save memory (#23379)

Signed-off-by: mgoin
---
 .../compressed_tensors_moe.py                | 20 +++++------
 .../schemes/compressed_tensors_w4a4_nvfp4.py | 11 +++---
 .../layers/quantization/modelopt.py          | 34 ++++++++-----------
 3 files changed, 30 insertions(+), 35 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 7bc35cd81ac3f..1ee3478aa4f43 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -246,13 +246,13 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
             return
 
         # swizzle weight scales
-        layer.w13_blockscale_swizzled = torch.nn.Parameter(swizzle_blockscale(
+        layer.w13_weight_scale = torch.nn.Parameter(swizzle_blockscale(
             layer.w13_weight_scale),
-                                                           requires_grad=False)
+                                                    requires_grad=False)
 
-        layer.w2_blockscale_swizzled = torch.nn.Parameter(swizzle_blockscale(
+        layer.w2_weight_scale = torch.nn.Parameter(swizzle_blockscale(
             layer.w2_weight_scale),
-                                                          requires_grad=False)
+                                                   requires_grad=False)
 
         # w13
         w13_input_global_scale = layer.w13_input_global_scale.max(
@@ -383,8 +383,8 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
             activation=activation,
             global_num_experts=global_num_experts,
             expert_map=expert_map,
-            w1_scale=layer.w13_blockscale_swizzled,
-            w2_scale=layer.w2_blockscale_swizzled,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
             apply_router_weight_on_input=apply_router_weight_on_input,
         )
 
@@ -406,8 +406,8 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
             activation=activation,
             global_num_experts=global_num_experts,
             expert_map=expert_map,
-            w1_scale=layer.w13_blockscale_swizzled,
-            w2_scale=layer.w2_blockscale_swizzled,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
             g1_alphas=layer.g1_alphas,
             g2_alphas=layer.g2_alphas,
             a1_gscale=layer.w13_input_scale_quant,
@@ -427,8 +427,8 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
             a=x,
             w1_fp4=layer.w13_weight,
             w2_fp4=layer.w2_weight,
-            w1_blockscale=layer.w13_blockscale_swizzled,
-            w2_blockscale=layer.w2_blockscale_swizzled,
+            w1_blockscale=layer.w13_weight_scale,
+            w2_blockscale=layer.w2_weight_scale,
             g1_alphas=layer.g1_alphas,
             g2_alphas=layer.g2_alphas,
             a1_gscale=layer.w13_input_scale_quant,
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
index 49d76bbeaa3a1..dedd681f15ded 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
@@ -112,13 +112,12 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
                     torch.uint8),
                 epilogue_tile_m).reshape(
                     weight_scale.shape).view(torch.float8_e4m3fn))
-            layer.weight_scale_swizzled = Parameter(weight_scale,
-                                                    requires_grad=False)
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
             layer.weight_packed = Parameter(weight, requires_grad=False)
         else:
             swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
-            layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
-                                                    requires_grad=False)
+            layer.weight_scale = Parameter(swizzled_weight_scale,
+                                           requires_grad=False)
             layer.weight_packed = Parameter(layer.weight_packed.data,
                                             requires_grad=False)
 
@@ -136,7 +135,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
                 x=x,
                 input_global_scale=layer.input_global_scale,
                 weight=layer.weight_packed,
-                weight_scale_swizzled=layer.weight_scale_swizzled,
+                weight_scale_swizzled=layer.weight_scale,
                 weight_global_scale=layer.weight_global_scale)
             if bias is not None:
                 out = out + bias
@@ -149,7 +148,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
 
         x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale)
         mm_args = (x_fp4, layer.weight_packed, x_blockscale,
-                   layer.weight_scale_swizzled, layer.alpha, output_dtype)
+                   layer.weight_scale, layer.alpha, output_dtype)
         if self.backend == "flashinfer-trtllm":
             out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm")
         elif self.backend == "flashinfer-cutlass":
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 046234057f04a..72864853f7e0c 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -907,20 +907,18 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
                     torch.uint8),
                 epilogue_tile_m).reshape(
                     weight_scale.shape).view(torch.float8_e4m3fn))
-            layer.weight_scale_swizzled = Parameter(weight_scale,
-                                                    requires_grad=False)
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
             layer.weight = Parameter(weight, requires_grad=False)
         else:
             swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
-            layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
-                                                    requires_grad=False)
+            layer.weight_scale = Parameter(swizzled_weight_scale,
+                                           requires_grad=False)
             layer.weight = Parameter(layer.weight.data, requires_grad=False)
 
         if self.backend == "marlin":
             prepare_fp4_layer_for_marlin(layer)
             del layer.alpha
             del layer.input_scale
-            del layer.weight_scale_swizzled
 
     def apply(
         self,
@@ -951,14 +949,14 @@
         assert (x_fp4.dtype == torch.uint8)
         assert (layer.weight.dtype == torch.uint8)
         assert (x_blockscale.dtype == torch.float8_e4m3fn)
-        assert (layer.weight_scale_swizzled.dtype == torch.float8_e4m3fn)
+        assert (layer.weight_scale.dtype == torch.float8_e4m3fn)
         assert (layer.alpha.dtype == torch.float32)
 
         mm_args = (
             x_fp4,
             layer.weight,
             x_blockscale,
-            layer.weight_scale_swizzled,
+            layer.weight_scale,
             layer.alpha,
             output_dtype,
         )
@@ -1320,16 +1318,16 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             "Weight Blockscale must be represented as FP8-E4M3")
         w13_blockscale_swizzled = swizzle_blockscale(
             layer.w13_weight_scale)
-        layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled,
-                                                  requires_grad=False)
+        layer.w13_weight_scale = Parameter(w13_blockscale_swizzled,
+                                           requires_grad=False)
 
         assert (layer.w2_weight_scale.shape[2] % 16 == 0), (
             "Expected weight_scale.dim(1) to be divisible by 16")
         assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), (
             "Weight Blockscale must be represented as FP8-E4M3")
         w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
-        layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
-                                                 requires_grad=False)
+        layer.w2_weight_scale = Parameter(w2_blockscale_swizzled,
+                                          requires_grad=False)
 
         layer.w2_weight = Parameter(layer.w2_weight.data,
                                     requires_grad=False)
@@ -1339,8 +1337,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             del layer.g2_alphas
             del layer.w13_input_scale_quant
             del layer.w2_input_scale_quant
-            del layer.w13_blockscale_swizzled
-            del layer.w2_blockscale_swizzled
 
     def apply(
         self,
@@ -1474,8 +1470,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                 activation=activation,
                 global_num_experts=global_num_experts,
                 expert_map=expert_map,
-                w1_scale=layer.w13_blockscale_swizzled,
-                w2_scale=layer.w2_blockscale_swizzled,
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
                 apply_router_weight_on_input=apply_router_weight_on_input,
             )
         elif (self.allow_flashinfer
@@ -1489,8 +1485,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                 w2=layer.w2_weight,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
-                w1_scale=layer.w13_blockscale_swizzled,
-                w2_scale=layer.w2_blockscale_swizzled,
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
                 g1_alphas=layer.g1_alphas,
                 g2_alphas=layer.g2_alphas,
                 a1_gscale=layer.w13_input_scale_quant,
@@ -1510,8 +1506,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                 a=x,
                 w1_fp4=layer.w13_weight,
                 w2_fp4=layer.w2_weight,
-                w1_blockscale=layer.w13_blockscale_swizzled,
-                w2_blockscale=layer.w2_blockscale_swizzled,
+                w1_blockscale=layer.w13_weight_scale,
+                w2_blockscale=layer.w2_weight_scale,
                 g1_alphas=layer.g1_alphas,
                 g2_alphas=layer.g2_alphas,
                 a1_gscale=layer.w13_input_scale_quant,
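
The change is mechanical: the swizzled blockscales produced in process_weights_after_loading now overwrite the existing weight_scale / w13_weight_scale / w2_weight_scale parameters instead of being registered under separate *_blockscale_swizzled names, so the unswizzled copies no longer stay resident and downstream kernels read layer.weight_scale directly. Below is a minimal sketch of that pattern, assuming a placeholder fake_swizzle_blockscale and a toy TinyNvFp4Layer in place of vLLM's real swizzling helper and quantized layers; it is an illustration, not vLLM code.

# Sketch only: fake_swizzle_blockscale and TinyNvFp4Layer are invented
# stand-ins for swizzle_blockscale and the quantized layers in the patch.
import torch


def fake_swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor:
    # Stand-in for the real blockscale swizzling: returns a contiguous copy
    # with the same shape so the example stays self-contained.
    return scale.clone().contiguous()


class TinyNvFp4Layer(torch.nn.Module):

    def __init__(self, num_blocks: int = 1024) -> None:
        super().__init__()
        # Raw per-block scales as loaded from the checkpoint.
        self.weight_scale = torch.nn.Parameter(torch.randn(num_blocks),
                                               requires_grad=False)

    def process_weights_after_loading(self) -> None:
        # Overwrite the parameter in place, as this patch does: the
        # unswizzled tensor has no remaining reference and can be freed,
        # and later code simply reads layer.weight_scale.
        self.weight_scale = torch.nn.Parameter(
            fake_swizzle_blockscale(self.weight_scale), requires_grad=False)


layer = TinyNvFp4Layer()
layer.process_weights_after_loading()
print(layer.weight_scale.shape)  # single, swizzled copy of the blockscales

Before this patch, the swizzled tensor was stored under a second attribute (weight_scale_swizzled, w13_blockscale_swizzled, w2_blockscale_swizzled) while the original scale parameter was kept as well, which is the duplication the commit subject refers to.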