cleanup quantization

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
Bill Nell 2025-05-28 20:19:41 +00:00
parent 909f234faa
commit 468d16654a
5 changed files with 96 additions and 78 deletions

View File

@ -270,8 +270,9 @@ def batched_moe(
topk_ids: torch.Tensor,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
use_fp8_w8a8: bool = False,
qtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None,
per_act_token: bool = False,
) -> torch.Tensor:
max_num_tokens = round_up(a.shape[0], 64) # ?
fused_experts = FusedMoEModularKernel(
@ -279,12 +280,13 @@ def batched_moe(
world_size=1,
dp_size=1,
rank=0,
use_fp8_w8a8=use_fp8_w8a8,
block_shape=block_shape),
qtype=qtype,
block_shape=block_shape,
per_act_token=False),
BatchedTritonExperts(max_num_tokens=max_num_tokens,
dp_size=1,
world_size=1,
use_fp8_w8a8=use_fp8_w8a8,
use_fp8_w8a8=qtype == torch.float8_e4m3fn,
block_shape=block_shape))
return fused_experts(a,
@ -360,7 +362,7 @@ def torch_moe2(
@pytest.mark.parametrize("k", [128, 512, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.torch.float8_e4m3fn, torch.bfloat16])
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
def test_fused_moe_batched_experts(
m: int,
n: int,
@ -378,6 +380,7 @@ def test_fused_moe_batched_experts(
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
use_fp8_w8a8 = dtype == torch.torch.float8_e4m3fn
qtype = dtype if dtype == torch.torch.float8_e4m3fn else None
if use_fp8_w8a8:
block_n, block_k = block_shape[0], block_shape[1]
@ -409,7 +412,7 @@ def test_fused_moe_batched_experts(
baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids, w1_s,
w2_s, use_fp8_w8a8, block_shape)
batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
w2_s, use_fp8_w8a8, block_shape)
w2_s, qtype, block_shape)
torch.testing.assert_close(baseline_output,
batched_output,

View File

@ -9,9 +9,9 @@ import triton.language as tl
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.fused_moe import (
get_config_dtype_str, try_get_optimal_moe_config)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache,
moe_kernel_quantize_input)
@triton.jit
@ -47,6 +47,7 @@ def moe_mmk(
compute_type: tl.constexpr,
use_w8a8: tl.constexpr,
use_w8a16: tl.constexpr):
offs_k = tl.arange(0, BLOCK_K)
if use_w8a16:
@ -325,6 +326,7 @@ def invoke_moe_batched_triton_kernel(
use_int4_w4a16: bool,
config: dict[str, int],
block_shape: Optional[list[int]] = None):
assert not use_int4_w4a16
max_num_tokens = A.size(1)
K = A.size(2)
@ -393,15 +395,17 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
world_size: int,
dp_size: int,
rank: int,
use_fp8_w8a8: bool = False,
qtype: Optional[torch.dtype] = None,
per_act_token: bool = False,
block_shape: Optional[list[int]] = None):
super().__init__()
self.world_size = world_size
self.dp_size = dp_size
self.rank = rank
self.max_num_tokens = max_num_tokens
self.use_fp8_w8a8 = use_fp8_w8a8
self.per_act_token = per_act_token
self.block_shape = block_shape
self.qtype = qtype
def prepare(
self,
@ -445,10 +449,10 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
b_a1 = torch.zeros(
(num_local_experts, self.max_num_tokens, hidden_dim),
dtype=torch.float8_e4m3fn if self.use_fp8_w8a8 else a1.dtype,
dtype=self.qtype if self.qtype is not None else a1.dtype,
device=a1.device)
if self.use_fp8_w8a8:
if self.qtype is not None:
k_tiles = (hidden_dim + block_k - 1) // block_k
b_a1_scale = torch.zeros(
(num_local_experts, self.max_num_tokens, k_tiles),
@ -465,10 +469,20 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
rows = torch.count_nonzero(topks.flatten())
rhs = a1[:topks.numel()][topks]
idx = expert_id - first_expert
if self.use_fp8_w8a8:
# TODO: use _fp8_quantize
b_a1[idx, :rows, :], b_a1_scale[
idx, :rows] = per_token_group_quant_fp8(rhs, block_k)
if self.qtype is not None:
if a1_scale is not None:
rhs_a1_scale = a1_scale[:topks.numel()][topks]
else:
rhs_a1_scale = None
b_a1[idx, :rows, :], b_a1_scale[idx, :rows] = (
moe_kernel_quantize_input(
rhs,
rhs_a1_scale,
self.qtype,
self.per_act_token,
self.block_shape,
)
)
else:
b_a1[idx, :rows, :] = rhs
@ -524,7 +538,6 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
block_m: Optional[int] = None,
):
super().__init__()
#assert block_shape is None
assert block_m is None
assert not use_int8_w8a8, "NYI"
assert not use_int8_w8a16, "NYI"
@ -615,6 +628,42 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
return out
def batched_moe_kernel_quantize_input(
    A: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    num_tokens: int,
    E: int,
    N: int,
    expert_num_tokens: torch.Tensor,
    qtype: Optional[torch.dtype],
    per_channel_quant: bool,
    block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Quantize a batched-expert activation tensor one expert at a time.

    Args:
        A: (E, num_tokens, hidden) batched activations.
        A_scale: optional per-expert input scales, indexed like A.
        num_tokens: max tokens per expert (second dim of the scale buffer).
        E: number of experts (first dim of A).
        N: the 2N intermediate size; scales are tiled over N // 2.
        expert_num_tokens: per-expert count of valid rows in A.
        qtype: target quantized dtype, or None to pass A through unchanged.
        per_channel_quant: forwarded to moe_kernel_quantize_input.
        block_shape: [block_n, block_k] quantization block sizes; required
            when qtype is not None.

    Returns:
        (A_q, A_q_scale): quantized activations and float32 scales, or
        (A, A_scale) unchanged when qtype is None.
    """
    if qtype is None:
        # No quantization requested: pass activations/scales through.
        return A, A_scale

    assert block_shape is not None
    A_q = torch.empty_like(A, dtype=qtype)
    block_n, block_k = block_shape
    n_tiles = ((N // 2) + block_n - 1) // block_n
    A_q_scale = torch.empty((E, num_tokens, n_tiles),
                            dtype=torch.float32,
                            device=A.device)
    for e in range(E):
        # Distinct name: the original reused `num_tokens`, shadowing the
        # parameter that sized A_q_scale above.
        e_num_tokens = expert_num_tokens[e]
        if e_num_tokens > 0:
            A_q[e, :e_num_tokens, :], tmp_scale = moe_kernel_quantize_input(
                A[e, :e_num_tokens],
                # Bug fix: `if A_scale` invokes Tensor.__bool__, which raises
                # RuntimeError for tensors with more than one element; test
                # identity against None instead.
                A_scale[e, :e_num_tokens] if A_scale is not None else None,
                qtype,
                per_channel_quant,
                # NOTE(review): this passes [block_k, block_n] — the reverse
                # of block_shape's [block_n, block_k] order. Preserved as-is;
                # confirm the argument order moe_kernel_quantize_input expects.
                [block_k, block_n])
            A_q_scale[e, :tmp_scale.shape[0]] = tmp_scale
    return A_q, A_q_scale
class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
"""
A Triton based MoE expert class that operates on expert batched format,
@ -630,6 +679,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
block_shape: Optional[list[int]] = None,
per_act_token: bool = False,
world_size: int = 1,
dp_size: int = 1,
):
@ -644,6 +694,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
assert not use_int4_w4a16, "NYI"
self.world_size = world_size
self.dp_size = dp_size
self.per_act_token = per_act_token
self.qtype = torch.float8_e4m3fn if self.use_fp8_w8a8 else None
def workspace_shapes(
self,
@ -731,7 +783,6 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
raise ValueError(
f"Unsupported compute_type: {hidden_states.dtype}")
#print(f"shape: E={E}, M={num_tokens}, N={N}, K={K}, top_k={top_k_num}")
# We can reuse the memory between these because by the time we need
# cache3, we're done with cache1
intermediate_cache1 = _resize_cache(workspace13, (E, num_tokens, N))
@ -761,36 +812,17 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
self.activation(activation, intermediate_cache2.view(-1, N // 2),
intermediate_cache1.view(-1, N))
#qintermediate_cache2 = intermediate_cache2
# TODO (varun) : support w8a8
#assert not self.use_fp8_w8a8
if self.use_fp8_w8a8:
per_act_token = False
# TODO: reuse?
qintermediate_cache2 = torch.empty_like(intermediate_cache2,
dtype=torch.float8_e4m3fn)
block_n = self.block_shape[0]
n_tiles = ((N // 2) + block_n - 1) // block_n
scale_shape = (E, num_tokens, n_tiles)
a2q_scale = torch.empty(scale_shape,
dtype=torch.float32,
device=hidden_states.device)
for e in range(E):
num_tokens = expert_num_tokens[e]
if num_tokens > 0:
#qintermediate_cache2[e], tmp_scale = _fp8_quantize(
# intermediate_cache2[e],
# a2_scale[e] if a2_scale is not None else None,
# per_act_token, self.block_shape)
qintermediate_cache2[
e, :
num_tokens, :], tmp_scale = per_token_group_quant_fp8(
intermediate_cache2[e, :num_tokens], block_n)
a2q_scale[e, :tmp_scale.shape[0]] = tmp_scale
else:
qintermediate_cache2 = intermediate_cache2
a2q_scale = a2_scale
qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input(
intermediate_cache2,
a2_scale,
num_tokens,
E,
N,
expert_num_tokens,
self.qtype,
self.per_act_token,
self.block_shape
)
invoke_moe_batched_triton_kernel(A=qintermediate_cache2,
B=w2,

View File

@ -1520,11 +1520,11 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
per_channel_quant: bool,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
block_shape: Optional[list[int]] = None,
block_m: Optional[int] = None,
):

View File

@ -192,7 +192,7 @@ class MoEConfig:
num_local_experts: int
moe_parallel_config: FusedMoEParallelConfig
in_dtype: torch.dtype # The activation type.
in_dtype: torch.dtype # The post quantization activation type.
# TODO: add more quantization params, blocked, per-token, etc.
block_size: int = 128
@ -489,22 +489,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
max_num_tokens=MOE_DP_CHUNK_SIZE,
world_size=world_size,
dp_size=dp_size,
use_fp8_w8a8=False, #moe.in_dtype == torch.float8_e4m3fn,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
block_shape=None,
)
else:
logger.debug("TritonExperts %s", self.moe)
experts = TritonExperts(
use_fp8_w8a8=False,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
block_shape=None,
per_channel_quant=False,
)
experts = TritonExperts()
self.fused_experts = FusedMoEModularKernel(
prepare_finalize,
@ -827,8 +815,7 @@ class FusedMoE(torch.nn.Module):
from vllm_hpu_extension.ops import DynamicFusedMOE
self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)
logger.debug(f"PARAM DTYPE = {params_dtype}")
#assert params_dtype.itemsize == 1
logger.debug("Model dtype = %s", vllm_config.model_config.dtype)
moe = MoEConfig(
max_num_tokens=MOE_DP_CHUNK_SIZE,
@ -838,7 +825,6 @@ class FusedMoE(torch.nn.Module):
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
in_dtype=moe.in_dtype,
max_num_tokens=MOE_DP_CHUNK_SIZE,
)
self.moe_config = moe
self.quant_config = quant_config
@ -877,15 +863,14 @@ class FusedMoE(torch.nn.Module):
self.batched_hidden_states: Optional[torch.Tensor] = None
self.batched_router_logits: Optional[torch.Tensor] = None
if self.moe_parallel_config.use_pplx_kernels:
act_dtype = vllm_config.model_config.dtype
self.batched_hidden_states = torch.zeros(
(MOE_DP_CHUNK_SIZE, self.hidden_size),
dtype=act_dtype,
dtype=moe.in_dtype,
device=torch.cuda.current_device())
self.batched_router_logits = torch.zeros(
(MOE_DP_CHUNK_SIZE, self.global_num_experts),
dtype=act_dtype,
dtype=moe.in_dtype,
device=torch.cuda.current_device())
@property

View File

@ -782,11 +782,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
max_num_tokens=MOE_DP_CHUNK_SIZE,
world_size=world_size,
dp_size=dp_size,
use_fp8_w8a8=True,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
qtype=torch.float8_e4m3fn,
block_shape=self.quant_config.weight_block_size,
per_act_token=False, #?
)
else:
logger.debug("TritonOrDeepGemmExperts(fp8)")