mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 00:15:51 +08:00
[Refactor] Refactor MOE NVFP4 Code Base: ModelOpt + Compressed Tensor (#21631)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
3d847a3125
commit
bda9d0535f
@ -17,7 +17,9 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
|
||||
CompressedTensorsW4A4Fp4, CompressedTensorsW4A16Fp4,
|
||||
CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
|
||||
CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
|
||||
CompressedTensorsWNA16, cutlass_fp4_supported)
|
||||
CompressedTensorsWNA16)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
cutlass_fp4_supported)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
sparse_cutlass_supported)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@ -33,7 +33,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||
find_matched_target, is_activation_quantization_format,
|
||||
should_ignore_layer)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
cutlass_fp4_supported)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@ -27,8 +27,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
prepare_moe_fp4_layer_for_marlin)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
prepare_moe_fp8_layer_for_marlin)
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
|
||||
cutlass_fp4_supported)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
cutlass_fp4_supported, swizzle_blockscale)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
@ -193,29 +193,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
|
||||
set_weight_attrs(w2_input_scale, extra_weight_attrs)
|
||||
|
||||
def swizzle_blockscale(self, scale: torch.tensor):
|
||||
assert (scale.dtype == torch.float8_e4m3fn)
|
||||
# Pad and blockwise interleave weight_scale
|
||||
scale_ndim = scale.ndim
|
||||
if scale.ndim == 2:
|
||||
scale = scale.unsqueeze(0)
|
||||
assert scale.ndim == 3
|
||||
B, M, K = scale.shape
|
||||
round_up_multiple = lambda x, m: (x + m - 1) // m * m
|
||||
M_padded = round_up_multiple(M, 128)
|
||||
K_padded = round_up_multiple(K, 4)
|
||||
padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
|
||||
padded_scale[:B, :M, :K] = scale
|
||||
batches, rows, cols = padded_scale.shape
|
||||
assert rows % 128 == 0
|
||||
assert cols % 4 == 0
|
||||
padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
|
||||
cols // 4, 4)
|
||||
swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
|
||||
swizzled_scale = swizzled_scale.contiguous().cuda()
|
||||
return (swizzled_scale.reshape(M, K)
|
||||
if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
|
||||
# From packed to weight
|
||||
@ -243,13 +220,13 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||
return
|
||||
|
||||
# swizzle weight scales
|
||||
layer.w13_blockscale_swizzled = torch.nn.Parameter(
|
||||
self.swizzle_blockscale(layer.w13_weight_scale),
|
||||
requires_grad=False)
|
||||
layer.w13_blockscale_swizzled = torch.nn.Parameter(swizzle_blockscale(
|
||||
layer.w13_weight_scale),
|
||||
requires_grad=False)
|
||||
|
||||
layer.w2_blockscale_swizzled = torch.nn.Parameter(
|
||||
self.swizzle_blockscale(layer.w2_weight_scale),
|
||||
requires_grad=False)
|
||||
layer.w2_blockscale_swizzled = torch.nn.Parameter(swizzle_blockscale(
|
||||
layer.w2_weight_scale),
|
||||
requires_grad=False)
|
||||
|
||||
# w13
|
||||
w13_input_global_scale = layer.w13_input_global_scale.max(
|
||||
|
||||
@ -9,8 +9,7 @@ from torch.nn.parameter import Parameter
|
||||
|
||||
import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm._custom_ops import (cutlass_scaled_fp4_mm,
|
||||
cutlass_scaled_mm_supports_fp4, scaled_fp4_quant)
|
||||
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
from vllm.distributed import get_ep_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
|
||||
@ -28,7 +27,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
apply_fp4_marlin_linear, is_fp4_marlin_supported,
|
||||
prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape, is_layer_skipped)
|
||||
GroupShape, cutlass_fp4_supported, is_layer_skipped, swizzle_blockscale)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp, requantize_with_max_scale)
|
||||
from vllm.model_executor.parameter import (ModelWeightParameter,
|
||||
@ -667,14 +666,6 @@ class ModelOptNvFp4Config(QuantizationConfig):
|
||||
return None
|
||||
|
||||
|
||||
def cutlass_fp4_supported() -> bool:
|
||||
if not current_platform.is_cuda():
|
||||
return False
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
capability = -1 if capability_tuple is None else capability_tuple.to_int()
|
||||
return cutlass_scaled_mm_supports_fp4(capability)
|
||||
|
||||
|
||||
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
|
||||
"""
|
||||
Supports loading kv-cache scaling factors from FP8 checkpoints.
|
||||
@ -772,29 +763,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
||||
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
def swizzle_blockscale(self, scale: torch.tensor):
|
||||
assert (scale.dtype == torch.float8_e4m3fn)
|
||||
# Pad and blockwise interleave weight_scale
|
||||
scale_ndim = scale.ndim
|
||||
if scale.ndim == 2:
|
||||
scale = scale.unsqueeze(0)
|
||||
assert scale.ndim == 3
|
||||
B, M, K = scale.shape
|
||||
round_up_multiple = lambda x, m: (x + m - 1) // m * m
|
||||
M_padded = round_up_multiple(M, 128)
|
||||
K_padded = round_up_multiple(K, 4)
|
||||
padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
|
||||
padded_scale[:B, :M, :K] = scale
|
||||
batches, rows, cols = padded_scale.shape
|
||||
assert rows % 128 == 0
|
||||
assert cols % 4 == 0
|
||||
padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
|
||||
cols // 4, 4)
|
||||
swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
|
||||
swizzled_scale = swizzled_scale.contiguous().cuda()
|
||||
return (swizzled_scale.reshape(M, K)
|
||||
if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
|
||||
# global scales:
|
||||
@ -814,7 +782,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
||||
"Expected weight_scale.dim(1) to be divisible by 16")
|
||||
assert (layer.weight_scale.dtype == torch.float8_e4m3fn), (
|
||||
"Weight Block scale must be represented as FP8-E4M3")
|
||||
swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
|
||||
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
|
||||
|
||||
layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
|
||||
requires_grad=False)
|
||||
@ -1060,29 +1028,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||
|
||||
def swizzle_blockscale(self, scale: torch.tensor):
|
||||
assert (scale.dtype == torch.float8_e4m3fn)
|
||||
# Pad and blockwise interleave weight_scale
|
||||
scale_ndim = scale.ndim
|
||||
if scale.ndim == 2:
|
||||
scale = scale.unsqueeze(0)
|
||||
assert scale.ndim == 3
|
||||
B, M, K = scale.shape
|
||||
round_up_multiple = lambda x, m: (x + m - 1) // m * m
|
||||
M_padded = round_up_multiple(M, 128)
|
||||
K_padded = round_up_multiple(K, 4)
|
||||
padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
|
||||
padded_scale[:B, :M, :K] = scale
|
||||
batches, rows, cols = padded_scale.shape
|
||||
assert rows % 128 == 0
|
||||
assert cols % 4 == 0
|
||||
padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
|
||||
cols // 4, 4)
|
||||
swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
|
||||
swizzled_scale = swizzled_scale.contiguous().cuda()
|
||||
return (swizzled_scale.reshape(M, K)
|
||||
if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# GEMM 1
|
||||
# The FlashInfer Cutlass fused MoE kernel expects the combined weights
|
||||
@ -1128,8 +1073,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
"Expected weight_scale.dim(1) to be divisible by 16")
|
||||
assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), (
|
||||
"Weight Blockscale must be represented as FP8-E4M3")
|
||||
w13_blockscale_swizzled = self.swizzle_blockscale(
|
||||
layer.w13_weight_scale)
|
||||
w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale)
|
||||
|
||||
layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled,
|
||||
requires_grad=False)
|
||||
@ -1151,7 +1095,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
"Expected weight_scale.dim(1) to be divisible by 16")
|
||||
assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), (
|
||||
"Weight Blockscale must be represented as FP8-E4M3")
|
||||
w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
|
||||
w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
|
||||
|
||||
layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
|
||||
requires_grad=False)
|
||||
|
||||
@ -2,13 +2,12 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
__all__ = [
|
||||
"break_fp4_bytes", "dequantize_to_dtype", "ref_nvfp4_quant",
|
||||
"cutlass_fp4_supported"
|
||||
"break_fp4_bytes",
|
||||
"dequantize_to_dtype",
|
||||
"ref_nvfp4_quant",
|
||||
]
|
||||
|
||||
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
|
||||
@ -17,14 +16,6 @@ kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.],
|
||||
dtype=torch.float32)
|
||||
|
||||
|
||||
def cutlass_fp4_supported() -> bool:
|
||||
if not current_platform.is_cuda():
|
||||
return False
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
capability = -1 if capability_tuple is None else capability_tuple.to_int()
|
||||
return cutlass_scaled_mm_supports_fp4(capability)
|
||||
|
||||
|
||||
def break_fp4_bytes(a, dtype):
|
||||
assert a.dtype == torch.uint8
|
||||
m, n = a.shape
|
||||
|
||||
@ -8,8 +8,10 @@ from typing import ClassVar, NamedTuple, Optional
|
||||
import numpy
|
||||
import torch
|
||||
|
||||
from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
|
||||
from vllm.model_executor.layers.quantization.qqq import (
|
||||
MARLIN_QQQ_SUPPORTED_NUM_BITS)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
|
||||
@ -592,3 +594,56 @@ def awq_pack(
|
||||
q_w = q_w.reshape((-1, size_n)).contiguous()
|
||||
|
||||
return pack_cols(q_w, num_bits, size_k, size_n)
|
||||
|
||||
|
||||
def swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Pad and block-interleave the FP4 block-scales so that they match the data
|
||||
layout expected by the CUTLASS / FlashInfer kernels.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
scale: torch.Tensor
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
The swizzled tensor with the same logical shape as *scale*.
|
||||
"""
|
||||
assert scale.dtype == torch.float8_e4m3fn, (
|
||||
"swizzle_blockscale expects the input tensor to be in "
|
||||
"torch.float8_e4m3fn format.")
|
||||
|
||||
scale_ndim = scale.ndim
|
||||
if scale_ndim == 2:
|
||||
scale = scale.unsqueeze(0) # (1, M, K)
|
||||
assert scale.ndim == 3, "Expected a 2-D or 3-D tensor for block scales."
|
||||
|
||||
B, M, K = scale.shape
|
||||
|
||||
def _round_up(x: int, m: int) -> int:
|
||||
return (x + m - 1) // m * m
|
||||
|
||||
M_padded = _round_up(M, 128)
|
||||
K_padded = _round_up(K, 4)
|
||||
|
||||
padded = torch.zeros((B, M_padded, K_padded),
|
||||
dtype=scale.dtype,
|
||||
device=scale.device)
|
||||
padded[:B, :M, :K] = scale
|
||||
|
||||
# Reshape / permute to the layout required by the kernel.
|
||||
padded = padded.reshape(B, M_padded // 128, 4, 32, K_padded // 4, 4)
|
||||
swizzled = padded.permute(0, 1, 4, 3, 2, 5).contiguous().cuda()
|
||||
|
||||
if scale_ndim == 2:
|
||||
return swizzled.reshape(M, K)
|
||||
return swizzled.reshape(B, M, K)
|
||||
|
||||
|
||||
def cutlass_fp4_supported() -> bool:
|
||||
if not current_platform.is_cuda():
|
||||
return False
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
capability = -1 if capability_tuple is None else capability_tuple.to_int()
|
||||
return cutlass_scaled_mm_supports_fp4(capability)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user