[Refactor] Refactor MOE NVFP4 Code Base: ModelOpt + Compressed Tensor (#21631)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Wentao Ye 2025-07-27 08:25:21 -04:00 committed by GitHub
parent 3d847a3125
commit bda9d0535f
6 changed files with 75 additions and 106 deletions
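
The commit deletes the per-backend copies of swizzle_blockscale and cutlass_fp4_supported and moves both into the shared quant_utils module. A minimal sketch (not part of the diff) of the import pattern the call sites converge on, taken from the hunks below:

# Both NVFP4 helpers now come from the shared quant_utils module.
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    cutlass_fp4_supported, swizzle_blockscale)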


@@ -17,7 +17,9 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
     CompressedTensorsW4A4Fp4, CompressedTensorsW4A16Fp4,
     CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
     CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
-    CompressedTensorsWNA16, cutlass_fp4_supported)
+    CompressedTensorsWNA16)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    cutlass_fp4_supported)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     sparse_cutlass_supported)
 from vllm.platforms import current_platform


@@ -33,7 +33,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     find_matched_target, is_activation_quantization_format,
     should_ignore_layer)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
-from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
     cutlass_fp4_supported)
 from vllm.platforms import current_platform


@@ -27,8 +27,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
     prepare_moe_fp4_layer_for_marlin)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     prepare_moe_fp8_layer_for_marlin)
-from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
-    cutlass_fp4_supported)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    cutlass_fp4_supported, swizzle_blockscale)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
 from vllm.model_executor.utils import set_weight_attrs
@@ -193,29 +193,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
             {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
         set_weight_attrs(w2_input_scale, extra_weight_attrs)
 
-    def swizzle_blockscale(self, scale: torch.tensor):
-        assert (scale.dtype == torch.float8_e4m3fn)
-        # Pad and blockwise interleave weight_scale
-        scale_ndim = scale.ndim
-        if scale.ndim == 2:
-            scale = scale.unsqueeze(0)
-        assert scale.ndim == 3
-        B, M, K = scale.shape
-        round_up_multiple = lambda x, m: (x + m - 1) // m * m
-        M_padded = round_up_multiple(M, 128)
-        K_padded = round_up_multiple(K, 4)
-        padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
-        padded_scale[:B, :M, :K] = scale
-        batches, rows, cols = padded_scale.shape
-        assert rows % 128 == 0
-        assert cols % 4 == 0
-        padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
-                                            cols // 4, 4)
-        swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
-        swizzled_scale = swizzled_scale.contiguous().cuda()
-        return (swizzled_scale.reshape(M, K)
-                if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
-
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # From packed to weight
@@ -243,13 +220,13 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
             return
 
         # swizzle weight scales
-        layer.w13_blockscale_swizzled = torch.nn.Parameter(
-            self.swizzle_blockscale(layer.w13_weight_scale),
-            requires_grad=False)
+        layer.w13_blockscale_swizzled = torch.nn.Parameter(swizzle_blockscale(
+            layer.w13_weight_scale),
+                                                           requires_grad=False)
 
-        layer.w2_blockscale_swizzled = torch.nn.Parameter(
-            self.swizzle_blockscale(layer.w2_weight_scale),
-            requires_grad=False)
+        layer.w2_blockscale_swizzled = torch.nn.Parameter(swizzle_blockscale(
+            layer.w2_weight_scale),
+                                                          requires_grad=False)
 
         # w13
         w13_input_global_scale = layer.w13_input_global_scale.max(


@@ -9,8 +9,7 @@ from torch.nn.parameter import Parameter
 
 import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
-from vllm._custom_ops import (cutlass_scaled_fp4_mm,
-                              cutlass_scaled_mm_supports_fp4, scaled_fp4_quant)
+from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
 from vllm.distributed import get_ep_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
@@ -28,7 +27,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
     apply_fp4_marlin_linear, is_fp4_marlin_supported,
     prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    GroupShape, is_layer_skipped)
+    GroupShape, cutlass_fp4_supported, is_layer_skipped, swizzle_blockscale)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     Fp8LinearOp, requantize_with_max_scale)
 from vllm.model_executor.parameter import (ModelWeightParameter,
@@ -667,14 +666,6 @@ class ModelOptNvFp4Config(QuantizationConfig):
         return None
 
 
-def cutlass_fp4_supported() -> bool:
-    if not current_platform.is_cuda():
-        return False
-    capability_tuple = current_platform.get_device_capability()
-    capability = -1 if capability_tuple is None else capability_tuple.to_int()
-    return cutlass_scaled_mm_supports_fp4(capability)
-
-
 class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
     """
     Supports loading kv-cache scaling factors from FP8 checkpoints.
@@ -772,29 +763,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
         layer.register_parameter("weight_scale", weight_scale)
 
-    def swizzle_blockscale(self, scale: torch.tensor):
-        assert (scale.dtype == torch.float8_e4m3fn)
-        # Pad and blockwise interleave weight_scale
-        scale_ndim = scale.ndim
-        if scale.ndim == 2:
-            scale = scale.unsqueeze(0)
-        assert scale.ndim == 3
-        B, M, K = scale.shape
-        round_up_multiple = lambda x, m: (x + m - 1) // m * m
-        M_padded = round_up_multiple(M, 128)
-        K_padded = round_up_multiple(K, 4)
-        padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
-        padded_scale[:B, :M, :K] = scale
-        batches, rows, cols = padded_scale.shape
-        assert rows % 128 == 0
-        assert cols % 4 == 0
-        padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
-                                            cols // 4, 4)
-        swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
-        swizzled_scale = swizzled_scale.contiguous().cuda()
-        return (swizzled_scale.reshape(M, K)
-                if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
-
     def process_weights_after_loading(self, layer: Module) -> None:
         # global scales:
@@ -814,7 +782,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
             "Expected weight_scale.dim(1) to be divisible by 16")
         assert (layer.weight_scale.dtype == torch.float8_e4m3fn), (
             "Weight Block scale must be represented as FP8-E4M3")
 
-        swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
+        swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
         layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
                                                 requires_grad=False)
@@ -1060,29 +1028,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             weight_loader=weight_loader)
         layer.register_parameter("w2_input_scale", w2_input_scale)
 
-    def swizzle_blockscale(self, scale: torch.tensor):
-        assert (scale.dtype == torch.float8_e4m3fn)
-        # Pad and blockwise interleave weight_scale
-        scale_ndim = scale.ndim
-        if scale.ndim == 2:
-            scale = scale.unsqueeze(0)
-        assert scale.ndim == 3
-        B, M, K = scale.shape
-        round_up_multiple = lambda x, m: (x + m - 1) // m * m
-        M_padded = round_up_multiple(M, 128)
-        K_padded = round_up_multiple(K, 4)
-        padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
-        padded_scale[:B, :M, :K] = scale
-        batches, rows, cols = padded_scale.shape
-        assert rows % 128 == 0
-        assert cols % 4 == 0
-        padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
-                                            cols // 4, 4)
-        swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
-        swizzled_scale = swizzled_scale.contiguous().cuda()
-        return (swizzled_scale.reshape(M, K)
-                if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
-
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # GEMM 1
         # The FlashInfer Cutlass fused MoE kernel expects the combined weights
@@ -1128,8 +1073,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             "Expected weight_scale.dim(1) to be divisible by 16")
         assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), (
             "Weight Blockscale must be represented as FP8-E4M3")
-        w13_blockscale_swizzled = self.swizzle_blockscale(
-            layer.w13_weight_scale)
+        w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale)
 
         layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled,
                                                   requires_grad=False)
@@ -1151,7 +1095,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             "Expected weight_scale.dim(1) to be divisible by 16")
         assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), (
             "Weight Blockscale must be represented as FP8-E4M3")
-        w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
+        w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
 
         layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
                                                  requires_grad=False)


@@ -2,13 +2,12 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import torch
 
-from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
-from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 
 __all__ = [
-    "break_fp4_bytes", "dequantize_to_dtype", "ref_nvfp4_quant",
-    "cutlass_fp4_supported"
+    "break_fp4_bytes",
+    "dequantize_to_dtype",
+    "ref_nvfp4_quant",
 ]
 
 FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
@@ -17,14 +16,6 @@ kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.],
                             dtype=torch.float32)
 
 
-def cutlass_fp4_supported() -> bool:
-    if not current_platform.is_cuda():
-        return False
-    capability_tuple = current_platform.get_device_capability()
-    capability = -1 if capability_tuple is None else capability_tuple.to_int()
-    return cutlass_scaled_mm_supports_fp4(capability)
-
-
 def break_fp4_bytes(a, dtype):
     assert a.dtype == torch.uint8
     m, n = a.shape


@@ -8,8 +8,10 @@ from typing import ClassVar, NamedTuple, Optional
 
 import numpy
 import torch
 
+from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
 from vllm.model_executor.layers.quantization.qqq import (
     MARLIN_QQQ_SUPPORTED_NUM_BITS)
+from vllm.platforms import current_platform
 from vllm.scalar_type import ScalarType, scalar_types
@@ -592,3 +594,56 @@ def awq_pack(
 
     q_w = q_w.reshape((-1, size_n)).contiguous()
     return pack_cols(q_w, num_bits, size_k, size_n)
+
+
+def swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor:
+    """
+    Pad and block-interleave the FP4 block-scales so that they match the data
+    layout expected by the CUTLASS / FlashInfer kernels.
+
+    Parameters
+    ----------
+    scale: torch.Tensor
+
+    Returns
+    -------
+    torch.Tensor
+        The swizzled tensor with the same logical shape as *scale*.
+    """
+    assert scale.dtype == torch.float8_e4m3fn, (
+        "swizzle_blockscale expects the input tensor to be in "
+        "torch.float8_e4m3fn format.")
+
+    scale_ndim = scale.ndim
+    if scale_ndim == 2:
+        scale = scale.unsqueeze(0)  # (1, M, K)
+    assert scale.ndim == 3, "Expected a 2-D or 3-D tensor for block scales."
+    B, M, K = scale.shape
+
+    def _round_up(x: int, m: int) -> int:
+        return (x + m - 1) // m * m
+
+    M_padded = _round_up(M, 128)
+    K_padded = _round_up(K, 4)
+    padded = torch.zeros((B, M_padded, K_padded),
+                         dtype=scale.dtype,
+                         device=scale.device)
+    padded[:B, :M, :K] = scale
+
+    # Reshape / permute to the layout required by the kernel.
+    padded = padded.reshape(B, M_padded // 128, 4, 32, K_padded // 4, 4)
+    swizzled = padded.permute(0, 1, 4, 3, 2, 5).contiguous().cuda()
+
+    if scale_ndim == 2:
+        return swizzled.reshape(M, K)
+    return swizzled.reshape(B, M, K)
+
+
+def cutlass_fp4_supported() -> bool:
+    if not current_platform.is_cuda():
+        return False
+    capability_tuple = current_platform.get_device_capability()
+    capability = -1 if capability_tuple is None else capability_tuple.to_int()
+    return cutlass_scaled_mm_supports_fp4(capability)
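
For reference, a minimal usage sketch of the shared swizzle_blockscale helper added above (not part of the diff): the tensor shape is illustrative, and a CUDA device is assumed because the helper moves the result to the GPU.

import torch

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    swizzle_blockscale)

# Illustrative (M, K) block-scale tensor with M a multiple of 128 and K a
# multiple of 4, stored as float8_e4m3fn as the helper's assert requires.
scale = torch.randn(256, 8).to(torch.float8_e4m3fn)
swizzled = swizzle_blockscale(scale)   # result lives on the GPU (.cuda() inside)
assert swizzled.shape == scale.shape   # logical (M, K) shape is preserved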
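
Likewise, a hedged sketch of how cutlass_fp4_supported is typically consumed by the backends touched in this commit; the backend names here are illustrative, and the real dispatch logic lives in the ModelOpt and compressed-tensors methods shown above.

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    cutlass_fp4_supported)

if cutlass_fp4_supported():
    # CUTLASS NVFP4 GEMM kernels are usable on this GPU, so the
    # cutlass_scaled_fp4_mm path can be taken.
    backend = "cutlass"
else:
    # Otherwise fall back to the Marlin / emulation paths.
    backend = "marlin-or-emulation"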