mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-12 06:37:03 +08:00
fp8 + pplx tests + fixes
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
parent
12ea698498
commit
922165cba3
@ -268,7 +268,7 @@ def batched_moe(
|
||||
block_shape: Optional[list[int]] = None,
|
||||
per_act_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
max_num_tokens = round_up(a.shape[0], 64) # ?
|
||||
max_num_tokens = round_up(a.shape[0], 64)
|
||||
fused_experts = FusedMoEModularKernel(
|
||||
BatchedPrepareAndFinalize(max_num_tokens,
|
||||
world_size=1,
|
||||
@ -342,9 +342,9 @@ def torch_moe2(
|
||||
topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [1, 33, 64, 222])
|
||||
@pytest.mark.parametrize("n", [128, 1024, 2048])
|
||||
@pytest.mark.parametrize("k", [128, 512, 1024])
|
||||
@pytest.mark.parametrize("m", [32, 45, 64]) #[1, 33, 64, 222])
|
||||
@pytest.mark.parametrize("n", [128, 512, 1024, 2048])
|
||||
@pytest.mark.parametrize("k", [128, 512, 1024, 2048])
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
|
||||
|
||||
@ -611,15 +611,12 @@ def pplx_moe(
|
||||
num_experts = w1.shape[0]
|
||||
block_size = block_shape[1] if block_shape is not None else 128
|
||||
topk = topk_ids.shape[1]
|
||||
if block_shape:
|
||||
max_num_tokens = round_up(rank_chunk(a.shape[0], 0, world_size), block_shape[0])
|
||||
else:
|
||||
max_num_tokens = rank_chunk(a.shape[0], 0, world_size)
|
||||
max_num_tokens = round_up(rank_chunk(a.shape[0], 0, world_size), 64)
|
||||
|
||||
if qtype is not None:
|
||||
a_dtype = qtype
|
||||
#print(f"SCALE BYTES {hidden_dim} {block_size} {((hidden_dim + block_size - 1) * torch.float32.itemsize) // block_size}")
|
||||
scale_bytes = 16
|
||||
# This is probably not right
|
||||
scale_bytes = round_up(((hidden_dim + block_size - 1) // block_size) * torch.float32.itemsize, 16)
|
||||
else:
|
||||
a_dtype = a.dtype
|
||||
scale_bytes = 0
|
||||
|
||||
@ -11,6 +11,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
get_config_dtype_str, try_get_optimal_moe_config)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
_resize_cache, moe_kernel_quantize_input)
|
||||
from vllm.utils import round_up
|
||||
|
||||
|
||||
@triton.jit
|
||||
@ -336,7 +337,7 @@ def invoke_moe_batched_triton_kernel(
|
||||
BLOCK_K = config['BLOCK_SIZE_K']
|
||||
assert (torch.compiler.is_compiling()
|
||||
or torch.cuda.is_current_stream_capturing()
|
||||
or max_num_tokens % BLOCK_M == 0)
|
||||
or max_num_tokens % BLOCK_M == 0), f"{max_num_tokens} {BLOCK_M}"
|
||||
|
||||
grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) *
|
||||
triton.cdiv(B.size(1), BLOCK_N))
|
||||
@ -666,7 +667,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_num_tokens: Optional[int] = None,
|
||||
max_num_tokens: int,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
@ -682,13 +683,13 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
self.use_int4_w4a16 = use_int4_w4a16
|
||||
self.use_int8_w8a16 = use_int8_w8a16
|
||||
self.block_shape = block_shape
|
||||
self.max_num_tokens = max_num_tokens
|
||||
assert not use_int8_w8a8, "NYI"
|
||||
assert not use_int4_w4a16, "NYI"
|
||||
self.world_size = world_size
|
||||
self.dp_size = dp_size
|
||||
self.per_act_token = per_act_token
|
||||
self.qtype = torch.float8_e4m3fn if self.use_fp8_w8a8 else None
|
||||
self.max_num_tokens = max_num_tokens
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
@ -701,10 +702,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
) -> tuple[int, int, torch.dtype]:
|
||||
assert a.dim() == 2
|
||||
num_dp = self.world_size // self.dp_size
|
||||
max_num_tokens = a.size(
|
||||
0) if self.max_num_tokens is None else self.max_num_tokens
|
||||
workspace13 = num_experts * max_num_tokens * num_dp * max(K, N)
|
||||
workspace2 = num_experts * max_num_tokens * num_dp * (N // 2)
|
||||
workspace13 = num_experts * self.max_num_tokens * num_dp * max(K, N)
|
||||
workspace2 = num_experts * self.max_num_tokens * num_dp * (N // 2)
|
||||
return (workspace13, workspace2, a.dtype)
|
||||
|
||||
def apply(
|
||||
|
||||
@ -29,7 +29,7 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import CpuArchEnum
|
||||
from vllm.utils import direct_register_custom_op
|
||||
from vllm.utils import direct_register_custom_op, cdiv
|
||||
|
||||
has_pplx = importlib.util.find_spec("pplx_kernels") is not None
|
||||
|
||||
@ -268,6 +268,15 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
|
||||
prepare_finalize = None
|
||||
if moe.use_pplx_kernels:
|
||||
# For blocked per token: set to
|
||||
# ceil_div(hidden_dim, block_size) * sizeof(float32)
|
||||
# For per-token: set to sizeof(float32)
|
||||
if moe.quant_dtype.itemsize == 1:
|
||||
scale_bytes = (cdiv(moe.hidden_dim, moe.block_size) *
|
||||
torch.float32.itemsize)
|
||||
else:
|
||||
scale_bytes = 0
|
||||
|
||||
all_to_all_args = dict(
|
||||
max_num_tokens=moe.max_num_tokens,
|
||||
num_experts=moe.num_experts,
|
||||
@ -278,12 +287,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
dp_size=all2all_manager.tp_group.world_size,
|
||||
hidden_dim=moe.hidden_dim,
|
||||
hidden_dim_bytes=moe.hidden_dim * moe.quant_dtype.itemsize,
|
||||
# For blocked per token: set to
|
||||
# ceil_div(hidden_dim, block_size) * sizeof(float32)
|
||||
# For per-token: set to sizeof(float32)
|
||||
hidden_dim_scale_bytes=(0 if moe.quant_dtype.itemsize != 1 else (
|
||||
((moe.hidden_dim + moe.block_size - 1) // moe.block_size) *
|
||||
torch.float32.itemsize)),
|
||||
hidden_dim_scale_bytes=scale_bytes,
|
||||
)
|
||||
|
||||
if not all2all_manager.internode:
|
||||
|
||||
@ -94,7 +94,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
float32_size = torch.float32.itemsize
|
||||
block_size = (self.block_shape[0] if self.block_shape is not None
|
||||
else 1) * float32_size
|
||||
expert_x_scale = torch.empty(
|
||||
expert_x_scale = torch.zeros(
|
||||
(
|
||||
num_experts,
|
||||
expert_x.size(1),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user