fp8 + pplx tests + fixes

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
Bill Nell 2025-05-29 21:25:33 +00:00
parent 12ea698498
commit 922165cba3
5 changed files with 25 additions and 25 deletions

View File

@ -268,7 +268,7 @@ def batched_moe(
block_shape: Optional[list[int]] = None, block_shape: Optional[list[int]] = None,
per_act_token: bool = False, per_act_token: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
max_num_tokens = round_up(a.shape[0], 64) # ? max_num_tokens = round_up(a.shape[0], 64)
fused_experts = FusedMoEModularKernel( fused_experts = FusedMoEModularKernel(
BatchedPrepareAndFinalize(max_num_tokens, BatchedPrepareAndFinalize(max_num_tokens,
world_size=1, world_size=1,
@ -342,9 +342,9 @@ def torch_moe2(
topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
@pytest.mark.parametrize("m", [1, 33, 64, 222]) @pytest.mark.parametrize("m", [32, 45, 64]) #[1, 33, 64, 222])
@pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("n", [128, 512, 1024, 2048])
@pytest.mark.parametrize("k", [128, 512, 1024]) @pytest.mark.parametrize("k", [128, 512, 1024, 2048])
@pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])

View File

@ -611,15 +611,12 @@ def pplx_moe(
num_experts = w1.shape[0] num_experts = w1.shape[0]
block_size = block_shape[1] if block_shape is not None else 128 block_size = block_shape[1] if block_shape is not None else 128
topk = topk_ids.shape[1] topk = topk_ids.shape[1]
if block_shape: max_num_tokens = round_up(rank_chunk(a.shape[0], 0, world_size), 64)
max_num_tokens = round_up(rank_chunk(a.shape[0], 0, world_size), block_shape[0])
else:
max_num_tokens = rank_chunk(a.shape[0], 0, world_size)
if qtype is not None: if qtype is not None:
a_dtype = qtype a_dtype = qtype
#print(f"SCALE BYTES {hidden_dim} {block_size} {((hidden_dim + block_size - 1) * torch.float32.itemsize) // block_size}") # This is probably not right
scale_bytes = 16 scale_bytes = round_up(((hidden_dim + block_size - 1) // block_size) * torch.float32.itemsize, 16)
else: else:
a_dtype = a.dtype a_dtype = a.dtype
scale_bytes = 0 scale_bytes = 0

View File

@ -11,6 +11,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
get_config_dtype_str, try_get_optimal_moe_config) get_config_dtype_str, try_get_optimal_moe_config)
from vllm.model_executor.layers.fused_moe.utils import ( from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, moe_kernel_quantize_input) _resize_cache, moe_kernel_quantize_input)
from vllm.utils import round_up
@triton.jit @triton.jit
@ -336,7 +337,7 @@ def invoke_moe_batched_triton_kernel(
BLOCK_K = config['BLOCK_SIZE_K'] BLOCK_K = config['BLOCK_SIZE_K']
assert (torch.compiler.is_compiling() assert (torch.compiler.is_compiling()
or torch.cuda.is_current_stream_capturing() or torch.cuda.is_current_stream_capturing()
or max_num_tokens % BLOCK_M == 0) or max_num_tokens % BLOCK_M == 0), f"{max_num_tokens} {BLOCK_M}"
grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) * grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) *
triton.cdiv(B.size(1), BLOCK_N)) triton.cdiv(B.size(1), BLOCK_N))
@ -666,7 +667,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__( def __init__(
self, self,
max_num_tokens: Optional[int] = None, max_num_tokens: int,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
@ -682,13 +683,13 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
self.use_int4_w4a16 = use_int4_w4a16 self.use_int4_w4a16 = use_int4_w4a16
self.use_int8_w8a16 = use_int8_w8a16 self.use_int8_w8a16 = use_int8_w8a16
self.block_shape = block_shape self.block_shape = block_shape
self.max_num_tokens = max_num_tokens
assert not use_int8_w8a8, "NYI" assert not use_int8_w8a8, "NYI"
assert not use_int4_w4a16, "NYI" assert not use_int4_w4a16, "NYI"
self.world_size = world_size self.world_size = world_size
self.dp_size = dp_size self.dp_size = dp_size
self.per_act_token = per_act_token self.per_act_token = per_act_token
self.qtype = torch.float8_e4m3fn if self.use_fp8_w8a8 else None self.qtype = torch.float8_e4m3fn if self.use_fp8_w8a8 else None
self.max_num_tokens = max_num_tokens
def workspace_shapes( def workspace_shapes(
self, self,
@ -701,10 +702,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
) -> tuple[int, int, torch.dtype]: ) -> tuple[int, int, torch.dtype]:
assert a.dim() == 2 assert a.dim() == 2
num_dp = self.world_size // self.dp_size num_dp = self.world_size // self.dp_size
max_num_tokens = a.size( workspace13 = num_experts * self.max_num_tokens * num_dp * max(K, N)
0) if self.max_num_tokens is None else self.max_num_tokens workspace2 = num_experts * self.max_num_tokens * num_dp * (N // 2)
workspace13 = num_experts * max_num_tokens * num_dp * max(K, N)
workspace2 = num_experts * max_num_tokens * num_dp * (N // 2)
return (workspace13, workspace2, a.dtype) return (workspace13, workspace2, a.dtype)
def apply( def apply(

View File

@ -29,7 +29,7 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum from vllm.platforms.interface import CpuArchEnum
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op, cdiv
has_pplx = importlib.util.find_spec("pplx_kernels") is not None has_pplx = importlib.util.find_spec("pplx_kernels") is not None
@ -268,6 +268,15 @@ class FusedMoEMethodBase(QuantizeMethodBase):
prepare_finalize = None prepare_finalize = None
if moe.use_pplx_kernels: if moe.use_pplx_kernels:
# For blocked per token: set to
# ceil_div(hidden_dim, block_size) * sizeof(float32)
# For per-token: set to sizeof(float32)
if moe.quant_dtype.itemsize == 1:
scale_bytes = (cdiv(moe.hidden_dim, moe.block_size) *
torch.float32.itemsize)
else:
scale_bytes = 0
all_to_all_args = dict( all_to_all_args = dict(
max_num_tokens=moe.max_num_tokens, max_num_tokens=moe.max_num_tokens,
num_experts=moe.num_experts, num_experts=moe.num_experts,
@ -278,12 +287,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
dp_size=all2all_manager.tp_group.world_size, dp_size=all2all_manager.tp_group.world_size,
hidden_dim=moe.hidden_dim, hidden_dim=moe.hidden_dim,
hidden_dim_bytes=moe.hidden_dim * moe.quant_dtype.itemsize, hidden_dim_bytes=moe.hidden_dim * moe.quant_dtype.itemsize,
# For blocked per token: set to hidden_dim_scale_bytes=scale_bytes,
# ceil_div(hidden_dim, block_size) * sizeof(float32)
# For per-token: set to sizeof(float32)
hidden_dim_scale_bytes=(0 if moe.quant_dtype.itemsize != 1 else (
((moe.hidden_dim + moe.block_size - 1) // moe.block_size) *
torch.float32.itemsize)),
) )
if not all2all_manager.internode: if not all2all_manager.internode:

View File

@ -94,7 +94,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
float32_size = torch.float32.itemsize float32_size = torch.float32.itemsize
block_size = (self.block_shape[0] if self.block_shape is not None block_size = (self.block_shape[0] if self.block_shape is not None
else 1) * float32_size else 1) * float32_size
expert_x_scale = torch.empty( expert_x_scale = torch.zeros(
( (
num_experts, num_experts,
expert_x.size(1), expert_x.size(1),