fp8 + pplx tests + fixes

Author: Bill Nell <bnell@redhat.com>
Date: 2025-05-29 21:25:33 +00:00
Commit: 922165cba3
Parent: 12ea698498
Signed-off-by: Bill Nell <bnell@redhat.com>

5 changed files with 25 additions and 25 deletions


@@ -268,7 +268,7 @@ def batched_moe(
     block_shape: Optional[list[int]] = None,
     per_act_token: bool = False,
 ) -> torch.Tensor:
-    max_num_tokens = round_up(a.shape[0], 64) # ?
+    max_num_tokens = round_up(a.shape[0], 64)
     fused_experts = FusedMoEModularKernel(
         BatchedPrepareAndFinalize(max_num_tokens,
                                   world_size=1,
@@ -342,9 +342,9 @@ def torch_moe2(
             topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)


-@pytest.mark.parametrize("m", [1, 33, 64, 222])
-@pytest.mark.parametrize("n", [128, 1024, 2048])
-@pytest.mark.parametrize("k", [128, 512, 1024])
+@pytest.mark.parametrize("m", [32, 45, 64]) #[1, 33, 64, 222])
+@pytest.mark.parametrize("n", [128, 512, 1024, 2048])
+@pytest.mark.parametrize("k", [128, 512, 1024, 2048])
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
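Aside: stacked @pytest.mark.parametrize decorators expand into a full cross-product, so trimming the m list directly cuts the generated case count. A rough sketch of the arithmetic; the sizes of NUM_EXPERTS and TOP_KS below are assumed for illustration, not taken from the file:

# Each decorator multiplies the number of generated test cases.
# Assuming NUM_EXPERTS and TOP_KS each hold 2 values (hypothetical sizes):
m, n, k, e, topk, dtypes = 3, 4, 4, 2, 2, 2
assert m * n * k * e * topk * dtypes == 384  # total generated test cases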


@@ -611,15 +611,12 @@ def pplx_moe(
     num_experts = w1.shape[0]
     block_size = block_shape[1] if block_shape is not None else 128
     topk = topk_ids.shape[1]
-    if block_shape:
-        max_num_tokens = round_up(rank_chunk(a.shape[0], 0, world_size), block_shape[0])
-    else:
-        max_num_tokens = rank_chunk(a.shape[0], 0, world_size)
+    max_num_tokens = round_up(rank_chunk(a.shape[0], 0, world_size), 64)

     if qtype is not None:
         a_dtype = qtype
-        #print(f"SCALE BYTES {hidden_dim} {block_size} {((hidden_dim + block_size - 1) * torch.float32.itemsize) // block_size}")
-        scale_bytes = 16
+        # This is probably not right
+        scale_bytes = round_up(((hidden_dim + block_size - 1) // block_size) * torch.float32.itemsize, 16)
     else:
         a_dtype = a.dtype
         scale_bytes = 0
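Aside: a minimal standalone sketch of the scale-buffer arithmetic this hunk lands on, assuming grouped fp8 quantization with one float32 scale per block of hidden elements; the cdiv and round_up helpers here are plain-Python stand-ins for the vllm.utils versions, not imports from the repo:

# One float32 scale per block of `block_size` hidden elements, with the
# per-token scale bytes padded to a 16-byte boundary for the buffer layout.
def cdiv(a: int, b: int) -> int:
    # Ceiling division: smallest n with n * b >= a.
    return (a + b - 1) // b

def round_up(x: int, align: int) -> int:
    # Round x up to the next multiple of `align`.
    return cdiv(x, align) * align

def scale_bytes_per_token(hidden_dim: int, block_size: int) -> int:
    num_blocks = cdiv(hidden_dim, block_size)   # scales needed per token
    return round_up(num_blocks * 4, 16)         # 4 == torch.float32.itemsize

assert scale_bytes_per_token(2048, 128) == 64   # 16 scales -> 64 bytes, aligned
assert scale_bytes_per_token(100, 128) == 16    # 1 scale -> 4 bytes, padded to 16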


@@ -11,6 +11,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
     get_config_dtype_str, try_get_optimal_moe_config)
 from vllm.model_executor.layers.fused_moe.utils import (
     _resize_cache, moe_kernel_quantize_input)
+from vllm.utils import round_up


 @triton.jit
@@ -336,7 +337,7 @@ def invoke_moe_batched_triton_kernel(
     BLOCK_K = config['BLOCK_SIZE_K']
     assert (torch.compiler.is_compiling()
             or torch.cuda.is_current_stream_capturing()
-            or max_num_tokens % BLOCK_M == 0)
+            or max_num_tokens % BLOCK_M == 0), f"{max_num_tokens} {BLOCK_M}"

     grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) *
             triton.cdiv(B.size(1), BLOCK_N))
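Aside: the launch grid tiles max_num_tokens into BLOCK_M-sized chunks, which is why the assert demands even divisibility (except under torch.compile or CUDA graph capture, where the shape may be symbolic). Illustrative numbers only:

# With max_num_tokens padded to a multiple of 64 (see the round_up calls
# elsewhere in this commit), any BLOCK_M that divides 64 tiles evenly.
max_num_tokens, BLOCK_M = 64, 16
num_tiles = -(-max_num_tokens // BLOCK_M)   # triton.cdiv equivalent -> 4
assert max_num_tokens % BLOCK_M == 0 and num_tiles == 4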
@@ -666,7 +667,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):

     def __init__(
         self,
-        max_num_tokens: Optional[int] = None,
+        max_num_tokens: int,
         use_fp8_w8a8: bool = False,
         use_int8_w8a8: bool = False,
         use_int8_w8a16: bool = False,
@@ -682,13 +683,13 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
         self.use_int4_w4a16 = use_int4_w4a16
         self.use_int8_w8a16 = use_int8_w8a16
         self.block_shape = block_shape
-        self.max_num_tokens = max_num_tokens
         assert not use_int8_w8a8, "NYI"
         assert not use_int4_w4a16, "NYI"
         self.world_size = world_size
         self.dp_size = dp_size
         self.per_act_token = per_act_token
         self.qtype = torch.float8_e4m3fn if self.use_fp8_w8a8 else None
+        self.max_num_tokens = max_num_tokens

     def workspace_shapes(
         self,
@@ -701,10 +702,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
     ) -> tuple[int, int, torch.dtype]:
         assert a.dim() == 2
         num_dp = self.world_size // self.dp_size
-        max_num_tokens = a.size(
-            0) if self.max_num_tokens is None else self.max_num_tokens
-        workspace13 = num_experts * max_num_tokens * num_dp * max(K, N)
-        workspace2 = num_experts * max_num_tokens * num_dp * (N // 2)
+        workspace13 = num_experts * self.max_num_tokens * num_dp * max(K, N)
+        workspace2 = num_experts * self.max_num_tokens * num_dp * (N // 2)
         return (workspace13, workspace2, a.dtype)

     def apply(
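Aside: with max_num_tokens now a required constructor argument (the Optional fallback to a.size(0) is gone), the workspace sizing is plain arithmetic over the batched layout: every expert owns a max_num_tokens x hidden slab per DP rank. A worked example with made-up shapes:

# Illustrative numbers only; the real values come from the model config.
num_experts = 8
max_num_tokens = 64   # now always set explicitly by the caller
num_dp = 2            # world_size // dp_size
N, K = 2048, 1024

workspace13 = num_experts * max_num_tokens * num_dp * max(K, N)  # 2_097_152 elements
workspace2 = num_experts * max_num_tokens * num_dp * (N // 2)    # 1_048_576 elements
assert (workspace13, workspace2) == (2_097_152, 1_048_576)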


@@ -29,7 +29,7 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
-from vllm.utils import direct_register_custom_op
+from vllm.utils import direct_register_custom_op, cdiv

 has_pplx = importlib.util.find_spec("pplx_kernels") is not None
@@ -268,6 +268,15 @@ class FusedMoEMethodBase(QuantizeMethodBase):
         prepare_finalize = None

         if moe.use_pplx_kernels:
+            # For blocked per token: set to
+            # ceil_div(hidden_dim, block_size) * sizeof(float32)
+            # For per-token: set to sizeof(float32)
+            if moe.quant_dtype.itemsize == 1:
+                scale_bytes = (cdiv(moe.hidden_dim, moe.block_size) *
+                               torch.float32.itemsize)
+            else:
+                scale_bytes = 0
+
             all_to_all_args = dict(
                 max_num_tokens=moe.max_num_tokens,
                 num_experts=moe.num_experts,
@@ -278,12 +287,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
                 dp_size=all2all_manager.tp_group.world_size,
                 hidden_dim=moe.hidden_dim,
                 hidden_dim_bytes=moe.hidden_dim * moe.quant_dtype.itemsize,
-                # For blocked per token: set to
-                # ceil_div(hidden_dim, block_size) * sizeof(float32)
-                # For per-token: set to sizeof(float32)
-                hidden_dim_scale_bytes=(0 if moe.quant_dtype.itemsize != 1 else (
-                    ((moe.hidden_dim + moe.block_size - 1) // moe.block_size) *
-                    torch.float32.itemsize)),
+                hidden_dim_scale_bytes=scale_bytes,
             )

             if not all2all_manager.internode:
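Aside: the two hunks above hoist the inline ceil-division out of the dict literal into a named scale_bytes variable built with vllm.utils.cdiv. A quick sketch of the equivalence, with a plain-Python stand-in for cdiv:

# cdiv(a, b) is just (a + b - 1) // b, so the refactor is behavior-preserving.
def cdiv(a: int, b: int) -> int:
    return (a + b - 1) // b

hidden_dim, block_size = 4096, 128
old = ((hidden_dim + block_size - 1) // block_size) * 4  # inline form, 4 = float32 bytes
new = cdiv(hidden_dim, block_size) * 4                   # hoisted form via cdiv
assert old == new == 128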


@@ -94,7 +94,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         float32_size = torch.float32.itemsize
         block_size = (self.block_shape[0] if self.block_shape is not None
                       else 1) * float32_size
-        expert_x_scale = torch.empty(
+        expert_x_scale = torch.zeros(
             (
                 num_experts,
                 expert_x.size(1),
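Aside: the commit does not state the motivation for the empty-to-zeros switch, but a plausible reading is that scale slots for padded or unfilled token positions must not hold garbage bits. A minimal illustration of the difference:

import torch

# torch.empty leaves the buffer uninitialized, so unwritten scale slots can
# hold arbitrary values (possibly NaN/inf when read as float); torch.zeros
# guarantees a well-defined 0.0 in every slot.
uninit = torch.empty(4, 4, dtype=torch.float32)  # contents are undefined
zeroed = torch.zeros(4, 4, dtype=torch.float32)  # contents are all 0.0
assert zeroed.abs().sum().item() == 0.0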