pplx + fp8 test

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
Bill Nell 2025-05-29 18:50:37 +00:00
parent caca0b718a
commit 12ea698498
4 changed files with 225 additions and 42 deletions

View File

@ -276,7 +276,7 @@ def batched_moe(
rank=0,
qtype=qtype,
block_shape=block_shape,
per_act_token=False),
per_act_token=per_act_token),
BatchedTritonExperts(max_num_tokens=max_num_tokens,
dp_size=1,
world_size=1,
@ -327,22 +327,13 @@ def torch_moe2(
tmp2 = SiluAndMul()(tmp1)
out[mask] = tmp2 @ w2[i].transpose(0, 1)
else:
#tmp1 = ops.cutlass_scaled_mm(a[mask],
# w1[i].transpose(0, 1),
# a_scale[mask],
# w1_scale[i],
# torch.bfloat16)
tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask],
w1_scale[i], block_shape,
torch.bfloat16)
tmp2 = SiluAndMul()(tmp1)
tmp2, b_scale = per_token_group_quant_fp8(tmp2, block_shape[1])
# out[mask] = ops.cutlass_scaled_mm(tmp2,
# w2[i].transpose(0, 1),
# b_scale,
# w2_scale[i],
# torch.bfloat16)
out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale,
w2_scale[i], block_shape,
torch.bfloat16)
@ -403,10 +394,10 @@ def test_fused_moe_batched_experts(
with set_current_vllm_config(vllm_config):
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids, w1_s,
w2_s, use_fp8_w8a8, block_shape)
batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
w2_s, qtype, block_shape)
baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids, w1_s,
w2_s, use_fp8_w8a8, block_shape)
torch.testing.assert_close(baseline_output,
batched_output,

View File

@ -33,7 +33,10 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk,
get_default_config)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.platforms import current_platform
from vllm.utils import round_up
PPLX_PREPARE_COMBOS = [(4, 128, 128), (32, 1024, 512), (64, 1024, 512),
(222, 2048, 1024)]
@ -280,6 +283,70 @@ def batched_moe(
return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts)
def native_w8a8_block_matmul(A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size,
output_dtype=torch.bfloat16):
"""This function performs matrix multiplication with block-wise
quantization using native torch.
It is agnostic to the input data type and can be used for both int8 and
fp8 data types.
It takes two input tensors `A` and `B` (int8) with scales `As` and
`Bs` (float32).
The output is returned in the specified `output_dtype`.
"""
A = A.to(torch.float32)
B = B.to(torch.float32).contiguous()
assert A.shape[-1] == B.shape[-1]
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1], (
f"{(A.shape[-1] + block_k - 1) // block_k} == {As.shape[-1]}")
assert A.shape[:-1] == As.shape[:-1], f"{A.shape} == {As.shape}"
M = A.numel() // A.shape[-1]
N, K = B.shape
origin_C_shape = A.shape[:-1] + (N, )
A = A.reshape(M, A.shape[-1])
As = As.reshape(M, As.shape[-1])
n_tiles = (N + block_n - 1) // block_n
k_tiles = (K + block_k - 1) // block_k
assert n_tiles == Bs.shape[0]
assert k_tiles == Bs.shape[1]
C_shape = (M, N)
C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)
A_tiles = [
A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)
]
B_tiles = [[
B[
j * block_n:min((j + 1) * block_n, N),
i * block_k:min((i + 1) * block_k, K),
] for i in range(k_tiles)
] for j in range(n_tiles)]
C_tiles = [
C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles)
]
As_tiles = [As[:, i:i + 1] for i in range(k_tiles)]
for i in range(k_tiles):
for j in range(n_tiles):
a = A_tiles[i]
b = B_tiles[j][i]
c = C_tiles[j]
s = As_tiles[i] * Bs[j][i]
c[:, :] += torch.matmul(a, b.t()) * s
C = C.reshape(origin_C_shape).to(output_dtype)
return C
# Note: same as torch_moe but with fused_topk factored out.
def torch_moe2(
a: torch.Tensor,
@ -287,17 +354,44 @@ def torch_moe2(
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
use_fp8_w8a8: bool = False,
block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
M, K = a.shape
topk = topk_ids.shape[1]
a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
if use_fp8_w8a8:
a, a_scale = per_token_group_quant_fp8(a, block_shape[1])
else:
a_scale = None
out = torch.zeros(M * topk,
w2.shape[1],
dtype=torch.bfloat16,
device=a.device)
num_experts = w1.shape[0]
for i in range(num_experts):
mask = (topk_ids == i).view(-1)
if mask.sum():
out[mask] = SiluAndMul()(
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
if not use_fp8_w8a8:
tmp1 = a[mask] @ w1[i].transpose(0, 1)
tmp2 = SiluAndMul()(tmp1)
out[mask] = tmp2 @ w2[i].transpose(0, 1)
else:
tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask],
w1_scale[i], block_shape,
torch.bfloat16)
tmp2 = SiluAndMul()(tmp1)
tmp2, b_scale = per_token_group_quant_fp8(tmp2, block_shape[1])
out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale,
w2_scale[i], block_shape,
torch.bfloat16)
return (out.view(M, -1, w2.shape[1]) *
topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
@ -502,6 +596,10 @@ def pplx_moe(
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
qtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None,
use_compile: bool = True,
use_cudagraphs: bool = True,
) -> torch.Tensor:
@ -511,9 +609,20 @@ def pplx_moe(
device = torch.device("cuda", rank)
hidden_dim = a.shape[1]
num_experts = w1.shape[0]
block_size = 128
block_size = block_shape[1] if block_shape is not None else 128
topk = topk_ids.shape[1]
max_num_tokens = rank_chunk(a.shape[0], 0, world_size)
if block_shape:
max_num_tokens = round_up(rank_chunk(a.shape[0], 0, world_size), block_shape[0])
else:
max_num_tokens = rank_chunk(a.shape[0], 0, world_size)
if qtype is not None:
a_dtype = qtype
#print(f"SCALE BYTES {hidden_dim} {block_size} {((hidden_dim + block_size - 1) * torch.float32.itemsize) // block_size}")
scale_bytes = 16
else:
a_dtype = a.dtype
scale_bytes = 0
ata = AllToAll.internode(
max_num_tokens=max_num_tokens,
@ -523,10 +632,8 @@ def pplx_moe(
world_size=world_size,
dp_size=dp_size,
hidden_dim=hidden_dim,
hidden_dim_bytes=hidden_dim * a.dtype.itemsize,
hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else
((hidden_dim + block_size - 1) // block_size *
torch.float32.itemsize)),
hidden_dim_bytes=hidden_dim * a_dtype.itemsize,
hidden_dim_scale_bytes=scale_bytes,
)
topk_ids = topk_ids.to(dtype=torch.uint32)
@ -537,11 +644,15 @@ def pplx_moe(
world_size,
rank,
dp_size,
quant_dtype=qtype,
block_shape=block_shape,
)
experts = BatchedTritonExperts(max_num_tokens=a.shape[0],
experts = BatchedTritonExperts(max_num_tokens=max_num_tokens,
world_size=world_size,
dp_size=dp_size)
dp_size=dp_size,
use_fp8_w8a8=qtype==torch.float8_e4m3fn,
block_shape=block_shape)
fused_experts = FusedMoEModularKernel(
prepare_finalize,
@ -557,7 +668,14 @@ def pplx_moe(
w1_chunk = chunk_by_rank(w1, rank, world_size).to(device)
w2_chunk = chunk_by_rank(w2, rank, world_size).to(device)
if use_compile:
if w1_scale is not None:
w1_scale_chunk = chunk_by_rank(w1_scale, rank, world_size).to(device)
w2_scale_chunk = chunk_by_rank(w2_scale, rank, world_size).to(device)
else:
w1_scale_chunk = None
w2_scale_chunk = None
if False and use_compile:
_fused_experts = torch.compile(fused_experts,
backend='inductor',
fullgraph=True)
@ -569,9 +687,11 @@ def pplx_moe(
w2_chunk,
chunk_topk_weight,
chunk_topk_ids,
w1_scale=w1_scale_chunk,
w2_scale=w2_scale_chunk,
global_num_experts=num_experts)
if use_cudagraphs:
if False and use_cudagraphs: #XXXXXXXXXXXX
out.fill_(0)
stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph()
@ -581,6 +701,8 @@ def pplx_moe(
w2_chunk,
chunk_topk_weight,
chunk_topk_ids,
w1_scale=w1_scale_chunk,
w2_scale=w2_scale_chunk,
global_num_experts=num_experts)
torch.cuda.synchronize()
@ -643,6 +765,10 @@ def _pplx_moe(
w2: torch.Tensor,
score: torch.Tensor,
topk: int,
w1_s: Optional[torch.Tensor] = None,
w2_s: Optional[torch.Tensor] = None,
qtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None,
):
uid = nvshmem_get_unique_id(
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
@ -654,11 +780,20 @@ def _pplx_moe(
moe_config = get_default_config(m, e, n, k, topk, a.dtype, False)
use_fp8_w8a8 = qtype == torch.float8_e4m3fn
device = torch.device("cuda", pgi.rank)
a = a.to(device)
w1 = w1.to(device)
w2 = w2.to(device)
w1_s = w1_s.to(device) if w1_s is not None else None
w2_s = w2_s.to(device) if w2_s is not None else None
with set_current_vllm_config(vllm_config), override_config(moe_config):
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids, w1_s, w2_s, use_fp8_w8a8, block_shape)
pplx_output = pplx_moe(pgi.rank, pgi.world_size, dp_size, a, w1, w2,
topk_weight, topk_ids)
topk_weight, topk_ids, w1_s, w2_s, qtype, block_shape)
# TODO (bnell): fix + re-enable
#batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight,
# topk_ids)
@ -675,7 +810,7 @@ def _pplx_moe(
@pytest.mark.parametrize("mnk", PPLX_MOE_COMBOS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@requires_pplx
def test_pplx_moe(
@ -688,9 +823,40 @@ def test_pplx_moe(
current_platform.seed_everything(7)
m, n, k = mnk
world_size, dp_size = world_dp_size
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
score = torch.randn((m, e), device="cuda", dtype=dtype)
a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16) / 10
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk)
use_fp8_w8a8 = dtype == torch.float8_e4m3fn
if use_fp8_w8a8:
block_shape = [128, 128]
quant_type = torch.float8_e4m3fn
block_n, block_k = block_shape[0], block_shape[1]
n_tiles_w1 = (2 * n + block_n - 1) // block_n
n_tiles_w2 = (k + block_n - 1) // block_n
k_tiles_w1 = (k + block_k - 1) // block_k
k_tiles_w2 = (n + block_k - 1) // block_k
finfo = torch.finfo(dtype)
fp8_min = finfo.min
fp8_max = finfo.max
w1 = w1.clamp(min=fp8_min, max=fp8_max).to(dtype)
w2 = w2.clamp(min=fp8_min, max=fp8_max).to(dtype)
factor_for_scale = 1e-2
w1_s = torch.rand(
(e, n_tiles_w1, k_tiles_w1), dtype=torch.float32,
device="cuda") * factor_for_scale
w2_s = torch.rand(
(e, n_tiles_w2, k_tiles_w2), dtype=torch.float32,
device="cuda") * factor_for_scale
else:
block_shape = None
quant_type = None
w1_s = None
w2_s = None
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, w1_s, w2_s, quant_type, block_shape)

View File

@ -457,6 +457,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
dtype=torch.float32,
device=a1.device)
else:
assert a1_scale is None
b_a1_scale = None
first_expert = num_local_experts * self.rank
@ -782,8 +783,6 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
(E, num_tokens, N // 2))
intermediate_cache3 = _resize_cache(workspace13, (E, num_tokens, K))
assert not self.use_fp8_w8a8 or a1q_scale is not None
# MM1
invoke_moe_batched_triton_kernel(A=hidden_states,
B=w1,

View File

@ -8,6 +8,9 @@ from typing import Callable, Optional
import torch
import torch.nn.functional as F
from compressed_tensors.quantization import (QuantizationArgs,
QuantizationStrategy,
QuantizationType)
from torch.nn.parameter import UninitializedParameter
import vllm.envs as envs
@ -238,6 +241,18 @@ class FusedMoeWeightScaleSupported(Enum):
BLOCK = "block"
def get_quant_config_input_activations(
quant_config: Optional[QuantizationConfig]
) -> Optional[QuantizationArgs]:
if (quant_config is not None and hasattr(quant_config, 'target_scheme_map')
and "Linear" in quant_config.target_scheme_map and
"input_activations" in quant_config.target_scheme_map["Linear"]):
return quant_config.target_scheme_map["Linear"].get(
"input_activations")
else:
return None
class FusedMoEMethodBase(QuantizeMethodBase):
@abstractmethod
@ -262,12 +277,12 @@ class FusedMoEMethodBase(QuantizeMethodBase):
# dp_size actually means tp_size, bug in pplx kernels
dp_size=all2all_manager.tp_group.world_size,
hidden_dim=moe.hidden_dim,
hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize,
hidden_dim_bytes=moe.hidden_dim * moe.quant_dtype.itemsize,
# For blocked per token: set to
# ceil_div(hidden_dim, block_size) * sizeof(float32)
# For per-token: set to sizeof(float32)
hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else (
(moe.hidden_dim + moe.block_size - 1) // moe.block_size *
hidden_dim_scale_bytes=(0 if moe.quant_dtype.itemsize != 1 else (
((moe.hidden_dim + moe.block_size - 1) // moe.block_size) *
torch.float32.itemsize)),
)
@ -285,7 +300,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
rank=all2all_manager.rank,
# dp_size actually means tp_size, bug in pplx kernels
dp_size=all2all_manager.tp_group.world_size,
quant_dtype=moe.in_dtype,
quant_dtype=moe.quant_dtype,
)
if prepare_finalize is not None:
@ -774,6 +789,17 @@ class FusedMoE(torch.nn.Module):
from vllm_hpu_extension.ops import DynamicFusedMOE
self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)
quant_dtype = vllm_config.model_config.dtype
if quant_config is not None:
input_activations = get_quant_config_input_activations(
quant_config)
if (input_activations is not None
and input_activations.num_bits == 8):
if input_activations.type == QuantizationType.FLOAT:
quant_dtype = torch.float8_e4m3fn
elif input_activations.type == QuantizationType.INT:
quant_dtype = torch.int8
moe = MoEConfig(
num_experts=self.global_num_experts,
experts_per_token=top_k,
@ -781,6 +807,7 @@ class FusedMoE(torch.nn.Module):
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
in_dtype=vllm_config.model_config.dtype,
quant_dtype=quant_dtype,
max_num_tokens=MOE_DP_CHUNK_SIZE,
)
self.moe_config = moe
@ -822,12 +849,12 @@ class FusedMoE(torch.nn.Module):
if self.moe_parallel_config.use_pplx_kernels:
self.batched_hidden_states = torch.zeros(
(MOE_DP_CHUNK_SIZE, self.hidden_size),
dtype=moe.in_dtype,
dtype=vllm_config.model_config.dtype,
device=torch.cuda.current_device())
self.batched_router_logits = torch.zeros(
(MOE_DP_CHUNK_SIZE, self.global_num_experts),
dtype=moe.in_dtype,
dtype=vllm_config.model_config.dtype,
device=torch.cuda.current_device())
@property