[Refactor] Refactor MOE NVFP4 Code Base: ModelOpt + Compressed Tensor (#21631)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Parent: 3d847a3125
Commit: bda9d0535f
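At a glance, the refactor removes the per-backend copies of swizzle_blockscale and cutlass_fp4_supported and re-homes them in quant_utils, so the compressed-tensors and ModelOpt NVFP4 paths share a single import. This snippet simply restates the import that appears in the hunks below:

    from vllm.model_executor.layers.quantization.utils.quant_utils import (
        cutlass_fp4_supported, swizzle_blockscale)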
@@ -17,7 +17,9 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
     CompressedTensorsW4A4Fp4, CompressedTensorsW4A16Fp4,
     CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
     CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
-    CompressedTensorsWNA16, cutlass_fp4_supported)
+    CompressedTensorsWNA16)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    cutlass_fp4_supported)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     sparse_cutlass_supported)
 from vllm.platforms import current_platform
@@ -33,7 +33,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     find_matched_target, is_activation_quantization_format,
     should_ignore_layer)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
-from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (  # noqa: E501
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
     cutlass_fp4_supported)
 from vllm.platforms import current_platform
 
@@ -27,8 +27,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
     prepare_moe_fp4_layer_for_marlin)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     prepare_moe_fp8_layer_for_marlin)
-from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (  # noqa: E501
-    cutlass_fp4_supported)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    cutlass_fp4_supported, swizzle_blockscale)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
 from vllm.model_executor.utils import set_weight_attrs
@@ -193,29 +193,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
             {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
         set_weight_attrs(w2_input_scale, extra_weight_attrs)
 
-    def swizzle_blockscale(self, scale: torch.tensor):
-        assert (scale.dtype == torch.float8_e4m3fn)
-        # Pad and blockwise interleave weight_scale
-        scale_ndim = scale.ndim
-        if scale.ndim == 2:
-            scale = scale.unsqueeze(0)
-        assert scale.ndim == 3
-        B, M, K = scale.shape
-        round_up_multiple = lambda x, m: (x + m - 1) // m * m
-        M_padded = round_up_multiple(M, 128)
-        K_padded = round_up_multiple(K, 4)
-        padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
-        padded_scale[:B, :M, :K] = scale
-        batches, rows, cols = padded_scale.shape
-        assert rows % 128 == 0
-        assert cols % 4 == 0
-        padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
-                                            cols // 4, 4)
-        swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
-        swizzled_scale = swizzled_scale.contiguous().cuda()
-        return (swizzled_scale.reshape(M, K)
-                if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
-
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
 
         # From packed to weight
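Aside: the method removed above reappears essentially verbatim as a module-level helper in quant_utils at the end of this diff. Its core is a pad-plus-block-interleave of the (M, K) scale matrix; below is a minimal CPU-only sketch of just the reshape/permute step, using a toy float32 tensor instead of the required float8_e4m3fn and skipping the padding and the .cuda() move:

    import torch

    # Toy scale matrix already padded to multiples of (128, 4).
    M, K = 128, 8
    scale = torch.arange(M * K, dtype=torch.float32).reshape(1, M, K)

    # Rows are grouped 128 at a time and split into 4 x 32 sub-blocks, columns
    # into groups of 4; the permute interleaves those sub-blocks.
    tiled = scale.reshape(1, M // 128, 4, 32, K // 4, 4)
    swizzled = tiled.permute(0, 1, 4, 3, 2, 5).contiguous().reshape(M, K)

    # Logical shape is preserved: both print torch.Size([128, 8]).
    print(scale.reshape(M, K).shape, swizzled.shape)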
@@ -243,13 +220,13 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
             return
 
         # swizzle weight scales
-        layer.w13_blockscale_swizzled = torch.nn.Parameter(
-            self.swizzle_blockscale(layer.w13_weight_scale),
+        layer.w13_blockscale_swizzled = torch.nn.Parameter(swizzle_blockscale(
+            layer.w13_weight_scale),
             requires_grad=False)
 
-        layer.w2_blockscale_swizzled = torch.nn.Parameter(
-            self.swizzle_blockscale(layer.w2_weight_scale),
+        layer.w2_blockscale_swizzled = torch.nn.Parameter(swizzle_blockscale(
+            layer.w2_weight_scale),
             requires_grad=False)
 
         # w13
         w13_input_global_scale = layer.w13_input_global_scale.max(
@@ -9,8 +9,7 @@ from torch.nn.parameter import Parameter
 
 import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
-from vllm._custom_ops import (cutlass_scaled_fp4_mm,
-                              cutlass_scaled_mm_supports_fp4, scaled_fp4_quant)
+from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
 from vllm.distributed import get_ep_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
@@ -28,7 +27,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
     apply_fp4_marlin_linear, is_fp4_marlin_supported,
     prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    GroupShape, is_layer_skipped)
+    GroupShape, cutlass_fp4_supported, is_layer_skipped, swizzle_blockscale)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     Fp8LinearOp, requantize_with_max_scale)
 from vllm.model_executor.parameter import (ModelWeightParameter,
@@ -667,14 +666,6 @@ class ModelOptNvFp4Config(QuantizationConfig):
         return None
 
 
-def cutlass_fp4_supported() -> bool:
-    if not current_platform.is_cuda():
-        return False
-    capability_tuple = current_platform.get_device_capability()
-    capability = -1 if capability_tuple is None else capability_tuple.to_int()
-    return cutlass_scaled_mm_supports_fp4(capability)
-
-
 class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
     """
     Supports loading kv-cache scaling factors from FP8 checkpoints.
@@ -772,29 +763,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
 
         layer.register_parameter("weight_scale", weight_scale)
 
-    def swizzle_blockscale(self, scale: torch.tensor):
-        assert (scale.dtype == torch.float8_e4m3fn)
-        # Pad and blockwise interleave weight_scale
-        scale_ndim = scale.ndim
-        if scale.ndim == 2:
-            scale = scale.unsqueeze(0)
-        assert scale.ndim == 3
-        B, M, K = scale.shape
-        round_up_multiple = lambda x, m: (x + m - 1) // m * m
-        M_padded = round_up_multiple(M, 128)
-        K_padded = round_up_multiple(K, 4)
-        padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
-        padded_scale[:B, :M, :K] = scale
-        batches, rows, cols = padded_scale.shape
-        assert rows % 128 == 0
-        assert cols % 4 == 0
-        padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
-                                            cols // 4, 4)
-        swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
-        swizzled_scale = swizzled_scale.contiguous().cuda()
-        return (swizzled_scale.reshape(M, K)
-                if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
-
     def process_weights_after_loading(self, layer: Module) -> None:
 
         # global scales:
@@ -814,7 +782,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
             "Expected weight_scale.dim(1) to be divisible by 16")
         assert (layer.weight_scale.dtype == torch.float8_e4m3fn), (
             "Weight Block scale must be represented as FP8-E4M3")
-        swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
+        swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
 
         layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
                                                 requires_grad=False)
@@ -1060,29 +1028,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                                      weight_loader=weight_loader)
         layer.register_parameter("w2_input_scale", w2_input_scale)
 
-    def swizzle_blockscale(self, scale: torch.tensor):
-        assert (scale.dtype == torch.float8_e4m3fn)
-        # Pad and blockwise interleave weight_scale
-        scale_ndim = scale.ndim
-        if scale.ndim == 2:
-            scale = scale.unsqueeze(0)
-        assert scale.ndim == 3
-        B, M, K = scale.shape
-        round_up_multiple = lambda x, m: (x + m - 1) // m * m
-        M_padded = round_up_multiple(M, 128)
-        K_padded = round_up_multiple(K, 4)
-        padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
-        padded_scale[:B, :M, :K] = scale
-        batches, rows, cols = padded_scale.shape
-        assert rows % 128 == 0
-        assert cols % 4 == 0
-        padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
-                                            cols // 4, 4)
-        swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
-        swizzled_scale = swizzled_scale.contiguous().cuda()
-        return (swizzled_scale.reshape(M, K)
-                if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
-
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # GEMM 1
         # The FlashInfer Cutlass fused MoE kernel expects the combined weights
@@ -1128,8 +1073,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             "Expected weight_scale.dim(1) to be divisible by 16")
         assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), (
             "Weight Blockscale must be represented as FP8-E4M3")
-        w13_blockscale_swizzled = self.swizzle_blockscale(
-            layer.w13_weight_scale)
+        w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale)
 
         layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled,
                                                   requires_grad=False)
@@ -1151,7 +1095,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             "Expected weight_scale.dim(1) to be divisible by 16")
         assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), (
             "Weight Blockscale must be represented as FP8-E4M3")
-        w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
+        w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
 
         layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
                                                  requires_grad=False)
@@ -2,13 +2,12 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import torch
 
-from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
-from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 
 __all__ = [
-    "break_fp4_bytes", "dequantize_to_dtype", "ref_nvfp4_quant",
-    "cutlass_fp4_supported"
+    "break_fp4_bytes",
+    "dequantize_to_dtype",
+    "ref_nvfp4_quant",
 ]
 
 FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
@@ -17,14 +16,6 @@ kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.],
                             dtype=torch.float32)
 
 
-def cutlass_fp4_supported() -> bool:
-    if not current_platform.is_cuda():
-        return False
-    capability_tuple = current_platform.get_device_capability()
-    capability = -1 if capability_tuple is None else capability_tuple.to_int()
-    return cutlass_scaled_mm_supports_fp4(capability)
-
-
 def break_fp4_bytes(a, dtype):
     assert a.dtype == torch.uint8
     m, n = a.shape
@@ -8,8 +8,10 @@ from typing import ClassVar, NamedTuple, Optional
 import numpy
 import torch
 
+from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
 from vllm.model_executor.layers.quantization.qqq import (
     MARLIN_QQQ_SUPPORTED_NUM_BITS)
+from vllm.platforms import current_platform
 from vllm.scalar_type import ScalarType, scalar_types
 
 
@@ -592,3 +594,56 @@ def awq_pack(
     q_w = q_w.reshape((-1, size_n)).contiguous()
 
     return pack_cols(q_w, num_bits, size_k, size_n)
+
+
+def swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor:
+    """
+    Pad and block-interleave the FP4 block-scales so that they match the data
+    layout expected by the CUTLASS / FlashInfer kernels.
+
+    Parameters
+    ----------
+    scale: torch.Tensor
+
+    Returns
+    -------
+    torch.Tensor
+        The swizzled tensor with the same logical shape as *scale*.
+    """
+    assert scale.dtype == torch.float8_e4m3fn, (
+        "swizzle_blockscale expects the input tensor to be in "
+        "torch.float8_e4m3fn format.")
+
+    scale_ndim = scale.ndim
+    if scale_ndim == 2:
+        scale = scale.unsqueeze(0)  # (1, M, K)
+    assert scale.ndim == 3, "Expected a 2-D or 3-D tensor for block scales."
+
+    B, M, K = scale.shape
+
+    def _round_up(x: int, m: int) -> int:
+        return (x + m - 1) // m * m
+
+    M_padded = _round_up(M, 128)
+    K_padded = _round_up(K, 4)
+
+    padded = torch.zeros((B, M_padded, K_padded),
+                         dtype=scale.dtype,
+                         device=scale.device)
+    padded[:B, :M, :K] = scale
+
+    # Reshape / permute to the layout required by the kernel.
+    padded = padded.reshape(B, M_padded // 128, 4, 32, K_padded // 4, 4)
+    swizzled = padded.permute(0, 1, 4, 3, 2, 5).contiguous().cuda()
+
+    if scale_ndim == 2:
+        return swizzled.reshape(M, K)
+    return swizzled.reshape(B, M, K)
+
+
+def cutlass_fp4_supported() -> bool:
+    if not current_platform.is_cuda():
+        return False
+    capability_tuple = current_platform.get_device_capability()
+    capability = -1 if capability_tuple is None else capability_tuple.to_int()
+    return cutlass_scaled_mm_supports_fp4(capability)
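For reference, a minimal usage sketch of the two consolidated helpers. The tensor shape and values are illustrative assumptions, not part of the diff, and the call is guarded because swizzle_blockscale moves its result to the GPU:

    import torch

    from vllm.model_executor.layers.quantization.utils.quant_utils import (
        cutlass_fp4_supported, swizzle_blockscale)

    # Hypothetical NVFP4 block-scale tensor: 256 rows x 48 scale groups,
    # stored as FP8-E4M3 as the helper requires.
    scale = torch.randn(256, 48).to(torch.float8_e4m3fn)

    if cutlass_fp4_supported():
        # Padded to multiples of (128, 4) internally, block-interleaved for the
        # CUTLASS/FlashInfer kernels, and returned with the same logical shape.
        swizzled = swizzle_blockscale(scale)
        print(swizzled.shape)  # torch.Size([256, 48])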