mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 05:04:58 +08:00
434 lines
14 KiB
Python
434 lines
14 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
from typing import Optional, Union
|
|
|
|
import torch
|
|
|
|
import vllm._custom_ops as ops
|
|
from tests.kernels.quant_utils import per_block_cast_to_int8
|
|
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
|
|
FLOAT8_E4M3_MAX)
|
|
from vllm.model_executor.layers.activation import SiluAndMul
|
|
from vllm.model_executor.layers.fused_moe import fused_experts
|
|
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
|
BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
|
|
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
|
FusedMoEModularKernel)
|
|
from vllm.model_executor.layers.fused_moe.utils import (
|
|
moe_kernel_quantize_input)
|
|
from vllm.utils import round_up
|
|
from vllm.utils.deep_gemm import per_block_cast_to_fp8
|
|
|
|
|
|
def triton_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    quant_dtype: Optional[torch.dtype] = None,
    per_act_token_quant: bool = False,
    block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
    """Run the (non-modular) Triton fused-MoE path as a test reference.

    Thin wrapper over ``fused_experts`` that translates this test suite's
    ``quant_dtype`` / ``per_act_token_quant`` vocabulary into the boolean
    flags ``fused_experts`` takes.

    Args:
        a: Input activations.
        w1: First-GEMM expert weights.
        w2: Second-GEMM expert weights.
        topk_weight: Router weights for the selected experts.
        topk_ids: Indices of the selected experts.
        w1_scale: Optional scales for ``w1``.
        w2_scale: Optional scales for ``w2``.
        a1_scale: Optional activation scales for the first GEMM.
        a2_scale: Optional activation scales for the second GEMM.
        quant_dtype: ``torch.float8_e4m3fn`` or ``torch.int8`` to select the
            matching W8A8 path; ``None`` for unquantized.
        per_act_token_quant: Forwarded as ``per_channel_quant``.
        block_shape: Optional block-quantization shape.

    Returns:
        The tensor returned by ``fused_experts``.
    """
    return fused_experts(a,
                         w1,
                         w2,
                         topk_weight,
                         topk_ids,
                         w1_scale=w1_scale,
                         w2_scale=w2_scale,
                         a1_scale=a1_scale,
                         a2_scale=a2_scale,
                         per_channel_quant=per_act_token_quant,
                         use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
                         # Previously only the fp8 flag was derived, so an
                         # int8 quant_dtype selected no quantized path.
                         use_int8_w8a8=quant_dtype == torch.int8,
                         block_shape=block_shape)
|
|
|
|
|
|
def batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    quant_dtype: Optional[torch.dtype] = None,
    per_act_token_quant: bool = False,
    block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
    """Run MoE through ``FusedMoEModularKernel`` with batched Triton experts.

    Builds a single-dispatcher (rank 0) ``BatchedPrepareAndFinalize`` paired
    with ``BatchedTritonExperts`` and applies it to the inputs.

    Args:
        a: Input activations; ``a.shape[0]`` sizes the batched buffers.
        w1: First-GEMM expert weights; ``w1.shape[0]`` is the expert count.
        w2: Second-GEMM expert weights.
        topk_weight: Router weights for the selected experts.
        topk_ids: Indices of the selected experts.
        w1_scale, w2_scale: Optional weight scales.
        a1_scale, a2_scale: Optional activation scales.
        quant_dtype: ``torch.float8_e4m3fn`` enables the fp8 W8A8 path.
        per_act_token_quant: Per-token activation quantization flag.
        block_shape: Optional block-quantization shape.

    Returns:
        The output of the modular kernel.
    """
    # max_num_tokens: the token count rounded up to a multiple of 64.
    capacity = round_up(a.shape[0], 64)
    num_experts = w1.shape[0]

    prepare_finalize = BatchedPrepareAndFinalize(capacity,
                                                 num_dispatchers=1,
                                                 num_local_experts=num_experts,
                                                 rank=0)
    experts = BatchedTritonExperts(
        max_num_tokens=capacity,
        num_dispatchers=1,
        use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
    )
    # Named `kernel` so it does not shadow the module-level fused_experts.
    kernel = FusedMoEModularKernel(prepare_finalize, experts)

    return kernel(a,
                  w1,
                  w2,
                  topk_weight,
                  topk_ids,
                  w1_scale=w1_scale,
                  w2_scale=w2_scale,
                  a1_scale=a1_scale,
                  a2_scale=a2_scale)
|
|
|
|
|
|
def naive_batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    quant_dtype: Optional[torch.dtype] = None,
    per_act_token_quant: bool = False,
    block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
    """Run MoE through ``FusedMoEModularKernel`` with naive batched experts.

    Same wiring as ``batched_moe`` but uses ``NaiveBatchedExperts`` for the
    expert computation, giving a reference for the batched path.

    Args:
        a: Input activations; ``a.shape[0]`` sizes the batched buffers.
        w1: First-GEMM expert weights; ``w1.shape[0]`` is the expert count.
        w2: Second-GEMM expert weights.
        topk_weight: Router weights for the selected experts.
        topk_ids: Indices of the selected experts.
        w1_scale, w2_scale: Optional weight scales.
        a1_scale, a2_scale: Optional activation scales.
        quant_dtype: ``torch.float8_e4m3fn`` enables the fp8 W8A8 path.
        per_act_token_quant: Per-token activation quantization flag.
        block_shape: Optional block-quantization shape.

    Returns:
        The output of the modular kernel.
    """
    # max_num_tokens: the token count rounded up to a multiple of 64.
    capacity = round_up(a.shape[0], 64)
    num_experts = w1.shape[0]

    prepare_finalize = BatchedPrepareAndFinalize(capacity,
                                                 num_dispatchers=1,
                                                 num_local_experts=num_experts,
                                                 rank=0)
    experts = NaiveBatchedExperts(
        max_num_tokens=capacity,
        num_dispatchers=1,
        use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
    )
    # Named `kernel` so it does not shadow the module-level fused_experts.
    kernel = FusedMoEModularKernel(prepare_finalize, experts)

    return kernel(a,
                  w1,
                  w2,
                  topk_weight,
                  topk_ids,
                  w1_scale=w1_scale,
                  w2_scale=w2_scale,
                  a1_scale=a1_scale,
                  a2_scale=a2_scale)
|
|
|
|
|
|
def chunk_scales(scales: Optional[torch.Tensor], start: int,
                 end: int) -> Optional[torch.Tensor]:
    """Slice ``scales`` to rows ``[start, end)`` unless it is a single value.

    A one-element tensor (a per-tensor scale) applies to every chunk and is
    returned unchanged (same object); ``None`` passes through as ``None``.
    """
    if scales is None:
        return None
    return scales if scales.numel() == 1 else scales[start:end]
|
|
|
|
|
|
def make_quantized_test_activations(
    E: int,
    m: int,
    k: int,
    in_dtype: torch.dtype,
    quant_dtype: Optional[torch.dtype] = None,
    block_shape: Optional[list[int]] = None,
    per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Create random per-expert activations and, optionally, quantize them.

    Args:
        E: Number of experts (leading dimension).
        m: Tokens per expert.
        k: Hidden size.
        in_dtype: dtype of the unquantized activations.
        quant_dtype: ``torch.float8_e4m3fn`` or ``torch.int8``; ``None``
            skips quantization.
        block_shape: Optional block-quantization shape.
        per_act_token_quant: Quantize with per-token scales.

    Returns:
        ``(a, a_q, a_scale)``. Without quantization ``a_q`` is ``a`` itself
        and ``a_scale`` is ``None``. For per-tensor quantization (neither
        per-token nor blocked) ``a_scale`` is reshaped to ``(E, 1, 1)``.
    """
    a = torch.randn((E, m, k), device="cuda", dtype=in_dtype) / 10

    if quant_dtype is None:
        return a, a, None

    assert (quant_dtype == torch.float8_e4m3fn
            or quant_dtype == torch.int8), "only fp8/int8 supported"

    a_q = torch.zeros_like(a, dtype=quant_dtype)
    scales = []
    for expert_idx in range(E):
        # Quantize each expert's slice independently, writing in place.
        a_q[expert_idx], expert_scale = moe_kernel_quantize_input(
            a[expert_idx], None, quant_dtype, per_act_token_quant,
            block_shape)
        scales.append(expert_scale)
    a_scale = torch.stack(scales)

    if not per_act_token_quant and block_shape is None:
        a_scale = a_scale.view(E, 1, 1)

    return a, a_q, a_scale
|
|
|
|
|
|
def moe_quantize_weights(
|
|
w: torch.Tensor,
|
|
w_s: Optional[torch.Tensor],
|
|
quant_dtype: Union[torch.dtype, str, None],
|
|
per_token_quant: bool,
|
|
block_shape: Optional[list[int]],
|
|
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
|
assert (quant_dtype == torch.float8_e4m3fn or quant_dtype == torch.int8
|
|
or quant_dtype == "nvfp4"), "only fp8/int8/nvfp4 supported"
|
|
|
|
w_gs = None
|
|
|
|
if block_shape is not None:
|
|
assert not per_token_quant
|
|
if quant_dtype == torch.int8:
|
|
w, w_s = per_block_cast_to_int8(w, block_shape)
|
|
elif quant_dtype == torch.float8_e4m3fn:
|
|
w, w_s = per_block_cast_to_fp8(w, block_shape)
|
|
elif quant_dtype == "nvfp4":
|
|
raise RuntimeError("blocked quantization not supported for nvfp4")
|
|
else:
|
|
raise RuntimeError(f"Unsupported quant type {quant_dtype}")
|
|
else:
|
|
if quant_dtype == torch.int8:
|
|
w, w_s = ops.scaled_int8_quant(
|
|
w, w_s, use_per_token_if_dynamic=per_token_quant)
|
|
elif quant_dtype == torch.float8_e4m3fn:
|
|
w, w_s = ops.scaled_fp8_quant(
|
|
w, w_s, use_per_token_if_dynamic=per_token_quant)
|
|
elif quant_dtype == "nvfp4":
|
|
assert not per_token_quant
|
|
w_amax = torch.abs(w).max().to(torch.float32)
|
|
w_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w_amax
|
|
w, w_s = ops.scaled_fp4_quant(w, w_gs)
|
|
else:
|
|
raise RuntimeError(f"Unsupported quant type {quant_dtype}")
|
|
|
|
return w, w_s, w_gs
|
|
|
|
|
|
def make_test_weight(
    e: int,
    rows: int,
    cols: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: Union[torch.dtype, str, None] = None,
    block_shape: Optional[list[int]] = None,
    per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
           Optional[torch.Tensor]]:
    """Create a random ``(e, rows, cols)`` expert weight, optionally quantized.

    Args:
        e: Number of experts.
        rows: Rows per expert matrix.
        cols: Columns per expert matrix.
        in_dtype: dtype of the unquantized weights.
        quant_dtype: Quantization type passed to ``moe_quantize_weights``;
            ``None`` skips quantization.
        block_shape: Optional ``[block_n, block_k]`` block shape.
        per_act_token_quant: Forwarded as per-token quantization.

    Returns:
        ``(w_16, w, w_s, w_gs)``: the unquantized weights, the (possibly)
        quantized weights, their scales, and global scales (nvfp4 only).
        Without quantization ``w`` is ``w_16`` and both scales are ``None``.
    """
    w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15

    if quant_dtype is None:
        return w_16, w_16, None, None

    # Quantize each expert separately, then stack along the expert dim.
    per_expert = [
        moe_quantize_weights(w_16[idx], None, quant_dtype,
                             per_act_token_quant, block_shape)
        for idx in range(e)
    ]
    w = torch.stack([q_w for q_w, _, _ in per_expert])
    w_s = torch.stack([q_s for _, q_s, _ in per_expert])

    w_gs = None
    if e > 0 and per_expert[0][2] is not None:
        w_gs = torch.stack([q_gs for _, _, q_gs in per_expert])

    if w_s.ndim == 2:
        # One scalar scale per expert: reshape (e, 1) -> (e, 1, 1).
        assert w_s.shape[-1] == 1
        w_s = w_s.view(-1, 1, 1)

    if block_shape is not None:
        block_n, block_k = block_shape
        n_tiles = (rows + block_n - 1) // block_n
        k_tiles = (cols + block_k - 1) // block_k
        assert w_s.shape == (e, n_tiles, k_tiles)

    return w_16, w, w_s, w_gs
|
|
|
|
|
|
def make_test_weights(
    e: int,
    n: int,
    k: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: Union[torch.dtype, str, None] = None,
    block_shape: Optional[list[int]] = None,
    per_act_token_quant: bool = False,
) -> tuple[tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
                 Optional[torch.Tensor]],
           tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
                 Optional[torch.Tensor]]]:
    """Create both MoE expert weight sets via ``make_test_weight``.

    The first set (w1) has shape ``(e, 2 * n, k)`` and the second (w2) has
    shape ``(e, k, n)``. Each entry is the 4-tuple returned by
    ``make_test_weight``.
    """
    first = make_test_weight(e, 2 * n, k, in_dtype, quant_dtype,
                             block_shape, per_act_token_quant)
    second = make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape,
                              per_act_token_quant)
    return first, second
|
|
|
|
|
|
def per_token_cast_to_fp8(
|
|
x: torch.Tensor,
|
|
block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
|
|
assert x.dim() == 2
|
|
m, n = x.shape
|
|
pad_size = (block_size - (n % block_size)) % block_size
|
|
x = torch.nn.functional.pad(x,
|
|
(0, pad_size), value=0) if pad_size > 0 else x
|
|
x_view = x.view(m, -1, block_size)
|
|
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
|
|
fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
|
|
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
|
|
|
|
|
|
# TODO(review): decide whether BaselineMM should be implemented as a vLLM CustomOp.
|
|
class BaselineMM(torch.nn.Module):
    """Reference matmul ``a @ b`` computed in float32.

    ``b`` is converted to float32 once at construction. ``forward`` upcasts
    ``a`` to float32, multiplies, casts the product to ``out_dtype`` and
    returns ``(product, None)`` so callers can unpack it as ``out, _ = mm(x)``.
    """

    def __init__(
        self,
        b: torch.Tensor,
        out_dtype: torch.dtype,
    ):
        super().__init__()
        self.b = b.to(dtype=torch.float32)
        self.out_dtype = out_dtype

    def forward(
            self,
            a: torch.Tensor) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        product = torch.mm(a.float(), self.b)
        return product.to(self.out_dtype), None
|
|
|
|
|
|
class TestMLP(torch.nn.Module):
    """Reference MLP: ``down_proj(act_fn(gate_up_proj(x)))``.

    Both projections are ``BaselineMM`` (float32 matmuls cast to
    ``out_dtype``) and the activation is ``SiluAndMul``.
    """

    def __init__(
        self,
        w1: torch.Tensor,
        w2: torch.Tensor,
        out_dtype: torch.dtype,
    ):
        super().__init__()
        self.gate_up_proj = BaselineMM(w1, out_dtype)
        self.down_proj = BaselineMM(w2, out_dtype)
        self.act_fn = SiluAndMul()

    def forward(self, x):
        # Each BaselineMM returns (output, None); keep only the output.
        hidden = self.act_fn(self.gate_up_proj(x)[0])
        return self.down_proj(hidden)[0]
|
|
|
|
|
|
def make_naive_shared_experts(
    N: int,
    K: int,
    in_dtype: torch.dtype = torch.bfloat16,
) -> torch.nn.Module:
    """Build an unquantized reference shared-expert MLP (``TestMLP``).

    Random CUDA weights scaled by 1/15: ``w1`` is ``(K, 2 * N)`` (its output
    feeds ``SiluAndMul``) and ``w2`` is ``(N, K)``.
    """
    gate_up_weight = torch.randn((K, N * 2), device="cuda",
                                 dtype=in_dtype) / 15
    down_weight = torch.randn((N, K), device="cuda", dtype=in_dtype) / 15
    return TestMLP(gate_up_weight, down_weight, out_dtype=in_dtype)
|
|
|
|
|
|
class RealMLP(torch.nn.Module):
    """Shared-expert MLP built from vLLM's parallel linear layers.

    The layers are constructed normally, then their ``weight`` and
    ``weight_scale`` parameters are replaced with the caller-supplied tensors
    and ``input_scale`` is registered as ``None``. The module therefore runs
    on the given (optionally quantized) weights without going through weight
    loading.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        w1: torch.Tensor,
        w2: torch.Tensor,
        hidden_act: str = "silu",
        quant_config=None,
        reduce_results: bool = True,
        prefix: str = "",
        w1_s: Optional[torch.Tensor] = None,
        w2_s: Optional[torch.Tensor] = None,
    ) -> None:
        """Build the gate/up and down projections around ``w1`` / ``w2``.

        Args:
            hidden_size: Model hidden size (input/output width).
            intermediate_size: Width of each of the gate and up halves.
            w1: Weight installed into ``gate_up_proj``.
            w2: Weight installed into ``down_proj``.
            hidden_act: Activation name; only ``"silu"`` is supported.
            quant_config: Optional vLLM quantization config for both layers.
            reduce_results: Forwarded to ``RowParallelLinear``.
            prefix: Layer-name prefix.
            w1_s: Optional scale for ``w1``.
            w2_s: Optional scale for ``w2``.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        from vllm.model_executor.layers.linear import (
            MergedColumnParallelLinear, RowParallelLinear)

        # Validate before constructing any layers so misuse fails fast.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")

        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj")
        self._install_params(self.gate_up_proj, w1, w1_s)

        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config,
                                           reduce_results=reduce_results,
                                           prefix=f"{prefix}.down_proj")
        self._install_params(self.down_proj, w2, w2_s)

        self.act_fn = SiluAndMul()

    @staticmethod
    def _install_params(layer: torch.nn.Module, weight: torch.Tensor,
                        weight_scale: Optional[torch.Tensor]) -> None:
        """Overwrite a layer's weight/scale params with frozen test tensors."""
        layer.register_parameter(
            "weight", torch.nn.Parameter(weight, requires_grad=False))
        # torch.nn.Parameter(None) silently becomes an empty tensor, so a
        # missing scale is registered as None instead.
        scale_param = (None if weight_scale is None else torch.nn.Parameter(
            weight_scale, requires_grad=False))
        layer.register_parameter("weight_scale", scale_param)
        layer.register_parameter("input_scale", None)

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
|
|
|
|
|
|
def make_shared_experts(
    N: int,
    K: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: Union[torch.dtype, str, None] = None,
) -> torch.nn.Module:
    """Build a ``RealMLP`` shared-expert module with random test weights.

    Weights come from expert 0 of ``make_test_weights(1, N, K, ...)``. For
    fp8 the weights and scales are transposed and an ``Fp8Config`` is used.
    Otherwise the unquantized weights are used with no quant config. The
    torch default dtype is set to ``in_dtype`` while the module is built and
    is restored afterwards.

    Args:
        N: Intermediate size.
        K: Hidden size.
        in_dtype: dtype of the unquantized weights / module default dtype.
        quant_dtype: ``None`` or ``torch.float8_e4m3fn``.

    Returns:
        The constructed ``RealMLP``.

    Raises:
        ValueError: If ``quant_dtype`` is neither ``None`` nor
            ``torch.float8_e4m3fn``. Previously such dtypes produced quantized
            weights that were then handed to an unquantized layer without
            their scales.
    """
    if quant_dtype is not None and quant_dtype != torch.float8_e4m3fn:
        raise ValueError(
            f"Unsupported quant_dtype for shared experts: {quant_dtype}. "
            "Only None and torch.float8_e4m3fn are supported.")

    from vllm.model_executor.layers.quantization.fp8 import Fp8Config

    (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
        1,
        N,
        K,
        in_dtype=in_dtype,
        quant_dtype=quant_dtype,
    )

    old_dtype = torch.get_default_dtype()
    try:
        # Presumably so parameters created inside RealMLP's layers use
        # in_dtype -- TODO confirm.
        torch.set_default_dtype(in_dtype)
        if quant_dtype == torch.float8_e4m3fn:
            # Expert 0, transposed; the fp8 linear path is assumed to expect
            # pre-transposed weights/scales here -- TODO confirm.
            w1 = w1[0].transpose(0, 1)
            w2 = w2[0].transpose(0, 1)
            w1_s = w1_s[0].transpose(0, 1) if w1_s is not None else None
            w2_s = w2_s[0].transpose(0, 1) if w2_s is not None else None
            quant_config = Fp8Config(True)
        else:
            w1 = w1[0]
            w2 = w2[0]
            w1_s = None
            w2_s = None
            quant_config = None

        return RealMLP(K,
                       N,
                       w1,
                       w2,
                       "silu",
                       quant_config,
                       w1_s=w1_s,
                       w2_s=w2_s)
    finally:
        torch.set_default_dtype(old_dtype)
|