mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-28 00:13:20 +08:00
pplx + fp8 test
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
parent
caca0b718a
commit
12ea698498
@ -276,7 +276,7 @@ def batched_moe(
|
|||||||
rank=0,
|
rank=0,
|
||||||
qtype=qtype,
|
qtype=qtype,
|
||||||
block_shape=block_shape,
|
block_shape=block_shape,
|
||||||
per_act_token=False),
|
per_act_token=per_act_token),
|
||||||
BatchedTritonExperts(max_num_tokens=max_num_tokens,
|
BatchedTritonExperts(max_num_tokens=max_num_tokens,
|
||||||
dp_size=1,
|
dp_size=1,
|
||||||
world_size=1,
|
world_size=1,
|
||||||
@ -327,22 +327,13 @@ def torch_moe2(
|
|||||||
tmp2 = SiluAndMul()(tmp1)
|
tmp2 = SiluAndMul()(tmp1)
|
||||||
out[mask] = tmp2 @ w2[i].transpose(0, 1)
|
out[mask] = tmp2 @ w2[i].transpose(0, 1)
|
||||||
else:
|
else:
|
||||||
#tmp1 = ops.cutlass_scaled_mm(a[mask],
|
|
||||||
# w1[i].transpose(0, 1),
|
|
||||||
# a_scale[mask],
|
|
||||||
# w1_scale[i],
|
|
||||||
# torch.bfloat16)
|
|
||||||
tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask],
|
tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask],
|
||||||
w1_scale[i], block_shape,
|
w1_scale[i], block_shape,
|
||||||
torch.bfloat16)
|
torch.bfloat16)
|
||||||
|
|
||||||
tmp2 = SiluAndMul()(tmp1)
|
tmp2 = SiluAndMul()(tmp1)
|
||||||
tmp2, b_scale = per_token_group_quant_fp8(tmp2, block_shape[1])
|
tmp2, b_scale = per_token_group_quant_fp8(tmp2, block_shape[1])
|
||||||
|
|
||||||
# out[mask] = ops.cutlass_scaled_mm(tmp2,
|
|
||||||
# w2[i].transpose(0, 1),
|
|
||||||
# b_scale,
|
|
||||||
# w2_scale[i],
|
|
||||||
# torch.bfloat16)
|
|
||||||
out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale,
|
out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale,
|
||||||
w2_scale[i], block_shape,
|
w2_scale[i], block_shape,
|
||||||
torch.bfloat16)
|
torch.bfloat16)
|
||||||
@ -403,10 +394,10 @@ def test_fused_moe_batched_experts(
|
|||||||
|
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
|
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||||
baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids, w1_s,
|
|
||||||
w2_s, use_fp8_w8a8, block_shape)
|
|
||||||
batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
|
batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
|
||||||
w2_s, qtype, block_shape)
|
w2_s, qtype, block_shape)
|
||||||
|
baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids, w1_s,
|
||||||
|
w2_s, use_fp8_w8a8, block_shape)
|
||||||
|
|
||||||
torch.testing.assert_close(baseline_output,
|
torch.testing.assert_close(baseline_output,
|
||||||
batched_output,
|
batched_output,
|
||||||
|
|||||||
@ -33,7 +33,10 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk,
|
|||||||
get_default_config)
|
get_default_config)
|
||||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||||
FusedMoEModularKernel)
|
FusedMoEModularKernel)
|
||||||
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
|
per_token_group_quant_fp8)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils import round_up
|
||||||
|
|
||||||
PPLX_PREPARE_COMBOS = [(4, 128, 128), (32, 1024, 512), (64, 1024, 512),
|
PPLX_PREPARE_COMBOS = [(4, 128, 128), (32, 1024, 512), (64, 1024, 512),
|
||||||
(222, 2048, 1024)]
|
(222, 2048, 1024)]
|
||||||
@ -280,6 +283,70 @@ def batched_moe(
|
|||||||
return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts)
|
return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts)
|
||||||
|
|
||||||
|
|
||||||
|
def native_w8a8_block_matmul(A: torch.Tensor,
|
||||||
|
B: torch.Tensor,
|
||||||
|
As: torch.Tensor,
|
||||||
|
Bs: torch.Tensor,
|
||||||
|
block_size,
|
||||||
|
output_dtype=torch.bfloat16):
|
||||||
|
"""This function performs matrix multiplication with block-wise
|
||||||
|
quantization using native torch.
|
||||||
|
It is agnostic to the input data type and can be used for both int8 and
|
||||||
|
fp8 data types.
|
||||||
|
|
||||||
|
It takes two input tensors `A` and `B` (int8) with scales `As` and
|
||||||
|
`Bs` (float32).
|
||||||
|
The output is returned in the specified `output_dtype`.
|
||||||
|
"""
|
||||||
|
A = A.to(torch.float32)
|
||||||
|
B = B.to(torch.float32).contiguous()
|
||||||
|
assert A.shape[-1] == B.shape[-1]
|
||||||
|
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
|
||||||
|
assert len(block_size) == 2
|
||||||
|
block_n, block_k = block_size[0], block_size[1]
|
||||||
|
assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1], (
|
||||||
|
f"{(A.shape[-1] + block_k - 1) // block_k} == {As.shape[-1]}")
|
||||||
|
assert A.shape[:-1] == As.shape[:-1], f"{A.shape} == {As.shape}"
|
||||||
|
|
||||||
|
M = A.numel() // A.shape[-1]
|
||||||
|
N, K = B.shape
|
||||||
|
origin_C_shape = A.shape[:-1] + (N, )
|
||||||
|
A = A.reshape(M, A.shape[-1])
|
||||||
|
As = As.reshape(M, As.shape[-1])
|
||||||
|
n_tiles = (N + block_n - 1) // block_n
|
||||||
|
k_tiles = (K + block_k - 1) // block_k
|
||||||
|
assert n_tiles == Bs.shape[0]
|
||||||
|
assert k_tiles == Bs.shape[1]
|
||||||
|
|
||||||
|
C_shape = (M, N)
|
||||||
|
C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)
|
||||||
|
|
||||||
|
A_tiles = [
|
||||||
|
A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)
|
||||||
|
]
|
||||||
|
B_tiles = [[
|
||||||
|
B[
|
||||||
|
j * block_n:min((j + 1) * block_n, N),
|
||||||
|
i * block_k:min((i + 1) * block_k, K),
|
||||||
|
] for i in range(k_tiles)
|
||||||
|
] for j in range(n_tiles)]
|
||||||
|
C_tiles = [
|
||||||
|
C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles)
|
||||||
|
]
|
||||||
|
As_tiles = [As[:, i:i + 1] for i in range(k_tiles)]
|
||||||
|
|
||||||
|
for i in range(k_tiles):
|
||||||
|
for j in range(n_tiles):
|
||||||
|
a = A_tiles[i]
|
||||||
|
b = B_tiles[j][i]
|
||||||
|
c = C_tiles[j]
|
||||||
|
s = As_tiles[i] * Bs[j][i]
|
||||||
|
c[:, :] += torch.matmul(a, b.t()) * s
|
||||||
|
|
||||||
|
C = C.reshape(origin_C_shape).to(output_dtype)
|
||||||
|
return C
|
||||||
|
|
||||||
|
|
||||||
# Note: same as torch_moe but with fused_topk factored out.
|
# Note: same as torch_moe but with fused_topk factored out.
|
||||||
def torch_moe2(
|
def torch_moe2(
|
||||||
a: torch.Tensor,
|
a: torch.Tensor,
|
||||||
@ -287,17 +354,44 @@ def torch_moe2(
|
|||||||
w2: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
topk_weight: torch.Tensor,
|
topk_weight: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
|
w1_scale: Optional[torch.Tensor] = None,
|
||||||
|
w2_scale: Optional[torch.Tensor] = None,
|
||||||
|
use_fp8_w8a8: bool = False,
|
||||||
|
block_shape: Optional[list[int]] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
M, K = a.shape
|
M, K = a.shape
|
||||||
topk = topk_ids.shape[1]
|
topk = topk_ids.shape[1]
|
||||||
|
|
||||||
a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
|
a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
|
||||||
out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
|
||||||
|
if use_fp8_w8a8:
|
||||||
|
a, a_scale = per_token_group_quant_fp8(a, block_shape[1])
|
||||||
|
else:
|
||||||
|
a_scale = None
|
||||||
|
|
||||||
|
out = torch.zeros(M * topk,
|
||||||
|
w2.shape[1],
|
||||||
|
dtype=torch.bfloat16,
|
||||||
|
device=a.device)
|
||||||
num_experts = w1.shape[0]
|
num_experts = w1.shape[0]
|
||||||
for i in range(num_experts):
|
for i in range(num_experts):
|
||||||
mask = (topk_ids == i).view(-1)
|
mask = (topk_ids == i).view(-1)
|
||||||
if mask.sum():
|
if mask.sum():
|
||||||
out[mask] = SiluAndMul()(
|
if not use_fp8_w8a8:
|
||||||
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
|
tmp1 = a[mask] @ w1[i].transpose(0, 1)
|
||||||
|
tmp2 = SiluAndMul()(tmp1)
|
||||||
|
out[mask] = tmp2 @ w2[i].transpose(0, 1)
|
||||||
|
else:
|
||||||
|
tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask],
|
||||||
|
w1_scale[i], block_shape,
|
||||||
|
torch.bfloat16)
|
||||||
|
|
||||||
|
tmp2 = SiluAndMul()(tmp1)
|
||||||
|
tmp2, b_scale = per_token_group_quant_fp8(tmp2, block_shape[1])
|
||||||
|
|
||||||
|
out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale,
|
||||||
|
w2_scale[i], block_shape,
|
||||||
|
torch.bfloat16)
|
||||||
|
|
||||||
return (out.view(M, -1, w2.shape[1]) *
|
return (out.view(M, -1, w2.shape[1]) *
|
||||||
topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
|
topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
|
||||||
@ -502,6 +596,10 @@ def pplx_moe(
|
|||||||
w2: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
topk_weight: torch.Tensor,
|
topk_weight: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
|
w1_scale: Optional[torch.Tensor] = None,
|
||||||
|
w2_scale: Optional[torch.Tensor] = None,
|
||||||
|
qtype: Optional[torch.dtype] = None,
|
||||||
|
block_shape: Optional[list[int]] = None,
|
||||||
use_compile: bool = True,
|
use_compile: bool = True,
|
||||||
use_cudagraphs: bool = True,
|
use_cudagraphs: bool = True,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
@ -511,9 +609,20 @@ def pplx_moe(
|
|||||||
device = torch.device("cuda", rank)
|
device = torch.device("cuda", rank)
|
||||||
hidden_dim = a.shape[1]
|
hidden_dim = a.shape[1]
|
||||||
num_experts = w1.shape[0]
|
num_experts = w1.shape[0]
|
||||||
block_size = 128
|
block_size = block_shape[1] if block_shape is not None else 128
|
||||||
topk = topk_ids.shape[1]
|
topk = topk_ids.shape[1]
|
||||||
max_num_tokens = rank_chunk(a.shape[0], 0, world_size)
|
if block_shape:
|
||||||
|
max_num_tokens = round_up(rank_chunk(a.shape[0], 0, world_size), block_shape[0])
|
||||||
|
else:
|
||||||
|
max_num_tokens = rank_chunk(a.shape[0], 0, world_size)
|
||||||
|
|
||||||
|
if qtype is not None:
|
||||||
|
a_dtype = qtype
|
||||||
|
#print(f"SCALE BYTES {hidden_dim} {block_size} {((hidden_dim + block_size - 1) * torch.float32.itemsize) // block_size}")
|
||||||
|
scale_bytes = 16
|
||||||
|
else:
|
||||||
|
a_dtype = a.dtype
|
||||||
|
scale_bytes = 0
|
||||||
|
|
||||||
ata = AllToAll.internode(
|
ata = AllToAll.internode(
|
||||||
max_num_tokens=max_num_tokens,
|
max_num_tokens=max_num_tokens,
|
||||||
@ -523,10 +632,8 @@ def pplx_moe(
|
|||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
dp_size=dp_size,
|
dp_size=dp_size,
|
||||||
hidden_dim=hidden_dim,
|
hidden_dim=hidden_dim,
|
||||||
hidden_dim_bytes=hidden_dim * a.dtype.itemsize,
|
hidden_dim_bytes=hidden_dim * a_dtype.itemsize,
|
||||||
hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else
|
hidden_dim_scale_bytes=scale_bytes,
|
||||||
((hidden_dim + block_size - 1) // block_size *
|
|
||||||
torch.float32.itemsize)),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
topk_ids = topk_ids.to(dtype=torch.uint32)
|
topk_ids = topk_ids.to(dtype=torch.uint32)
|
||||||
@ -537,11 +644,15 @@ def pplx_moe(
|
|||||||
world_size,
|
world_size,
|
||||||
rank,
|
rank,
|
||||||
dp_size,
|
dp_size,
|
||||||
|
quant_dtype=qtype,
|
||||||
|
block_shape=block_shape,
|
||||||
)
|
)
|
||||||
|
|
||||||
experts = BatchedTritonExperts(max_num_tokens=a.shape[0],
|
experts = BatchedTritonExperts(max_num_tokens=max_num_tokens,
|
||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
dp_size=dp_size)
|
dp_size=dp_size,
|
||||||
|
use_fp8_w8a8=qtype==torch.float8_e4m3fn,
|
||||||
|
block_shape=block_shape)
|
||||||
|
|
||||||
fused_experts = FusedMoEModularKernel(
|
fused_experts = FusedMoEModularKernel(
|
||||||
prepare_finalize,
|
prepare_finalize,
|
||||||
@ -557,7 +668,14 @@ def pplx_moe(
|
|||||||
w1_chunk = chunk_by_rank(w1, rank, world_size).to(device)
|
w1_chunk = chunk_by_rank(w1, rank, world_size).to(device)
|
||||||
w2_chunk = chunk_by_rank(w2, rank, world_size).to(device)
|
w2_chunk = chunk_by_rank(w2, rank, world_size).to(device)
|
||||||
|
|
||||||
if use_compile:
|
if w1_scale is not None:
|
||||||
|
w1_scale_chunk = chunk_by_rank(w1_scale, rank, world_size).to(device)
|
||||||
|
w2_scale_chunk = chunk_by_rank(w2_scale, rank, world_size).to(device)
|
||||||
|
else:
|
||||||
|
w1_scale_chunk = None
|
||||||
|
w2_scale_chunk = None
|
||||||
|
|
||||||
|
if False and use_compile:
|
||||||
_fused_experts = torch.compile(fused_experts,
|
_fused_experts = torch.compile(fused_experts,
|
||||||
backend='inductor',
|
backend='inductor',
|
||||||
fullgraph=True)
|
fullgraph=True)
|
||||||
@ -569,9 +687,11 @@ def pplx_moe(
|
|||||||
w2_chunk,
|
w2_chunk,
|
||||||
chunk_topk_weight,
|
chunk_topk_weight,
|
||||||
chunk_topk_ids,
|
chunk_topk_ids,
|
||||||
|
w1_scale=w1_scale_chunk,
|
||||||
|
w2_scale=w2_scale_chunk,
|
||||||
global_num_experts=num_experts)
|
global_num_experts=num_experts)
|
||||||
|
|
||||||
if use_cudagraphs:
|
if False and use_cudagraphs: #XXXXXXXXXXXX
|
||||||
out.fill_(0)
|
out.fill_(0)
|
||||||
stream = torch.cuda.Stream()
|
stream = torch.cuda.Stream()
|
||||||
graph = torch.cuda.CUDAGraph()
|
graph = torch.cuda.CUDAGraph()
|
||||||
@ -581,6 +701,8 @@ def pplx_moe(
|
|||||||
w2_chunk,
|
w2_chunk,
|
||||||
chunk_topk_weight,
|
chunk_topk_weight,
|
||||||
chunk_topk_ids,
|
chunk_topk_ids,
|
||||||
|
w1_scale=w1_scale_chunk,
|
||||||
|
w2_scale=w2_scale_chunk,
|
||||||
global_num_experts=num_experts)
|
global_num_experts=num_experts)
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
@ -643,6 +765,10 @@ def _pplx_moe(
|
|||||||
w2: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
score: torch.Tensor,
|
score: torch.Tensor,
|
||||||
topk: int,
|
topk: int,
|
||||||
|
w1_s: Optional[torch.Tensor] = None,
|
||||||
|
w2_s: Optional[torch.Tensor] = None,
|
||||||
|
qtype: Optional[torch.dtype] = None,
|
||||||
|
block_shape: Optional[list[int]] = None,
|
||||||
):
|
):
|
||||||
uid = nvshmem_get_unique_id(
|
uid = nvshmem_get_unique_id(
|
||||||
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
|
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
|
||||||
@ -654,11 +780,20 @@ def _pplx_moe(
|
|||||||
|
|
||||||
moe_config = get_default_config(m, e, n, k, topk, a.dtype, False)
|
moe_config = get_default_config(m, e, n, k, topk, a.dtype, False)
|
||||||
|
|
||||||
|
use_fp8_w8a8 = qtype == torch.float8_e4m3fn
|
||||||
|
|
||||||
|
device = torch.device("cuda", pgi.rank)
|
||||||
|
a = a.to(device)
|
||||||
|
w1 = w1.to(device)
|
||||||
|
w2 = w2.to(device)
|
||||||
|
w1_s = w1_s.to(device) if w1_s is not None else None
|
||||||
|
w2_s = w2_s.to(device) if w2_s is not None else None
|
||||||
|
|
||||||
with set_current_vllm_config(vllm_config), override_config(moe_config):
|
with set_current_vllm_config(vllm_config), override_config(moe_config):
|
||||||
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
|
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||||
torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
|
torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids, w1_s, w2_s, use_fp8_w8a8, block_shape)
|
||||||
pplx_output = pplx_moe(pgi.rank, pgi.world_size, dp_size, a, w1, w2,
|
pplx_output = pplx_moe(pgi.rank, pgi.world_size, dp_size, a, w1, w2,
|
||||||
topk_weight, topk_ids)
|
topk_weight, topk_ids, w1_s, w2_s, qtype, block_shape)
|
||||||
# TODO (bnell): fix + re-enable
|
# TODO (bnell): fix + re-enable
|
||||||
#batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight,
|
#batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight,
|
||||||
# topk_ids)
|
# topk_ids)
|
||||||
@ -675,7 +810,7 @@ def _pplx_moe(
|
|||||||
@pytest.mark.parametrize("mnk", PPLX_MOE_COMBOS)
|
@pytest.mark.parametrize("mnk", PPLX_MOE_COMBOS)
|
||||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||||
@pytest.mark.parametrize("topk", TOP_KS)
|
@pytest.mark.parametrize("topk", TOP_KS)
|
||||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
|
||||||
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
|
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
|
||||||
@requires_pplx
|
@requires_pplx
|
||||||
def test_pplx_moe(
|
def test_pplx_moe(
|
||||||
@ -688,9 +823,40 @@ def test_pplx_moe(
|
|||||||
current_platform.seed_everything(7)
|
current_platform.seed_everything(7)
|
||||||
m, n, k = mnk
|
m, n, k = mnk
|
||||||
world_size, dp_size = world_dp_size
|
world_size, dp_size = world_dp_size
|
||||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
|
||||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16) / 10
|
||||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16) / 10
|
||||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk)
|
use_fp8_w8a8 = dtype == torch.float8_e4m3fn
|
||||||
|
|
||||||
|
if use_fp8_w8a8:
|
||||||
|
block_shape = [128, 128]
|
||||||
|
quant_type = torch.float8_e4m3fn
|
||||||
|
block_n, block_k = block_shape[0], block_shape[1]
|
||||||
|
n_tiles_w1 = (2 * n + block_n - 1) // block_n
|
||||||
|
n_tiles_w2 = (k + block_n - 1) // block_n
|
||||||
|
k_tiles_w1 = (k + block_k - 1) // block_k
|
||||||
|
k_tiles_w2 = (n + block_k - 1) // block_k
|
||||||
|
|
||||||
|
finfo = torch.finfo(dtype)
|
||||||
|
fp8_min = finfo.min
|
||||||
|
fp8_max = finfo.max
|
||||||
|
|
||||||
|
w1 = w1.clamp(min=fp8_min, max=fp8_max).to(dtype)
|
||||||
|
w2 = w2.clamp(min=fp8_min, max=fp8_max).to(dtype)
|
||||||
|
|
||||||
|
factor_for_scale = 1e-2
|
||||||
|
w1_s = torch.rand(
|
||||||
|
(e, n_tiles_w1, k_tiles_w1), dtype=torch.float32,
|
||||||
|
device="cuda") * factor_for_scale
|
||||||
|
w2_s = torch.rand(
|
||||||
|
(e, n_tiles_w2, k_tiles_w2), dtype=torch.float32,
|
||||||
|
device="cuda") * factor_for_scale
|
||||||
|
else:
|
||||||
|
block_shape = None
|
||||||
|
quant_type = None
|
||||||
|
w1_s = None
|
||||||
|
w2_s = None
|
||||||
|
|
||||||
|
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, w1_s, w2_s, quant_type, block_shape)
|
||||||
|
|||||||
@ -457,6 +457,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
device=a1.device)
|
device=a1.device)
|
||||||
else:
|
else:
|
||||||
|
assert a1_scale is None
|
||||||
b_a1_scale = None
|
b_a1_scale = None
|
||||||
|
|
||||||
first_expert = num_local_experts * self.rank
|
first_expert = num_local_experts * self.rank
|
||||||
@ -782,8 +783,6 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
(E, num_tokens, N // 2))
|
(E, num_tokens, N // 2))
|
||||||
intermediate_cache3 = _resize_cache(workspace13, (E, num_tokens, K))
|
intermediate_cache3 = _resize_cache(workspace13, (E, num_tokens, K))
|
||||||
|
|
||||||
assert not self.use_fp8_w8a8 or a1q_scale is not None
|
|
||||||
|
|
||||||
# MM1
|
# MM1
|
||||||
invoke_moe_batched_triton_kernel(A=hidden_states,
|
invoke_moe_batched_triton_kernel(A=hidden_states,
|
||||||
B=w1,
|
B=w1,
|
||||||
|
|||||||
@ -8,6 +8,9 @@ from typing import Callable, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from compressed_tensors.quantization import (QuantizationArgs,
|
||||||
|
QuantizationStrategy,
|
||||||
|
QuantizationType)
|
||||||
from torch.nn.parameter import UninitializedParameter
|
from torch.nn.parameter import UninitializedParameter
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
@ -238,6 +241,18 @@ class FusedMoeWeightScaleSupported(Enum):
|
|||||||
BLOCK = "block"
|
BLOCK = "block"
|
||||||
|
|
||||||
|
|
||||||
|
def get_quant_config_input_activations(
|
||||||
|
quant_config: Optional[QuantizationConfig]
|
||||||
|
) -> Optional[QuantizationArgs]:
|
||||||
|
if (quant_config is not None and hasattr(quant_config, 'target_scheme_map')
|
||||||
|
and "Linear" in quant_config.target_scheme_map and
|
||||||
|
"input_activations" in quant_config.target_scheme_map["Linear"]):
|
||||||
|
return quant_config.target_scheme_map["Linear"].get(
|
||||||
|
"input_activations")
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class FusedMoEMethodBase(QuantizeMethodBase):
|
class FusedMoEMethodBase(QuantizeMethodBase):
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -262,12 +277,12 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
|||||||
# dp_size actually means tp_size, bug in pplx kernels
|
# dp_size actually means tp_size, bug in pplx kernels
|
||||||
dp_size=all2all_manager.tp_group.world_size,
|
dp_size=all2all_manager.tp_group.world_size,
|
||||||
hidden_dim=moe.hidden_dim,
|
hidden_dim=moe.hidden_dim,
|
||||||
hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize,
|
hidden_dim_bytes=moe.hidden_dim * moe.quant_dtype.itemsize,
|
||||||
# For blocked per token: set to
|
# For blocked per token: set to
|
||||||
# ceil_div(hidden_dim, block_size) * sizeof(float32)
|
# ceil_div(hidden_dim, block_size) * sizeof(float32)
|
||||||
# For per-token: set to sizeof(float32)
|
# For per-token: set to sizeof(float32)
|
||||||
hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else (
|
hidden_dim_scale_bytes=(0 if moe.quant_dtype.itemsize != 1 else (
|
||||||
(moe.hidden_dim + moe.block_size - 1) // moe.block_size *
|
((moe.hidden_dim + moe.block_size - 1) // moe.block_size) *
|
||||||
torch.float32.itemsize)),
|
torch.float32.itemsize)),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -285,7 +300,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
|||||||
rank=all2all_manager.rank,
|
rank=all2all_manager.rank,
|
||||||
# dp_size actually means tp_size, bug in pplx kernels
|
# dp_size actually means tp_size, bug in pplx kernels
|
||||||
dp_size=all2all_manager.tp_group.world_size,
|
dp_size=all2all_manager.tp_group.world_size,
|
||||||
quant_dtype=moe.in_dtype,
|
quant_dtype=moe.quant_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
if prepare_finalize is not None:
|
if prepare_finalize is not None:
|
||||||
@ -774,6 +789,17 @@ class FusedMoE(torch.nn.Module):
|
|||||||
from vllm_hpu_extension.ops import DynamicFusedMOE
|
from vllm_hpu_extension.ops import DynamicFusedMOE
|
||||||
self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)
|
self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)
|
||||||
|
|
||||||
|
quant_dtype = vllm_config.model_config.dtype
|
||||||
|
if quant_config is not None:
|
||||||
|
input_activations = get_quant_config_input_activations(
|
||||||
|
quant_config)
|
||||||
|
if (input_activations is not None
|
||||||
|
and input_activations.num_bits == 8):
|
||||||
|
if input_activations.type == QuantizationType.FLOAT:
|
||||||
|
quant_dtype = torch.float8_e4m3fn
|
||||||
|
elif input_activations.type == QuantizationType.INT:
|
||||||
|
quant_dtype = torch.int8
|
||||||
|
|
||||||
moe = MoEConfig(
|
moe = MoEConfig(
|
||||||
num_experts=self.global_num_experts,
|
num_experts=self.global_num_experts,
|
||||||
experts_per_token=top_k,
|
experts_per_token=top_k,
|
||||||
@ -781,6 +807,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
num_local_experts=self.local_num_experts,
|
num_local_experts=self.local_num_experts,
|
||||||
moe_parallel_config=self.moe_parallel_config,
|
moe_parallel_config=self.moe_parallel_config,
|
||||||
in_dtype=vllm_config.model_config.dtype,
|
in_dtype=vllm_config.model_config.dtype,
|
||||||
|
quant_dtype=quant_dtype,
|
||||||
max_num_tokens=MOE_DP_CHUNK_SIZE,
|
max_num_tokens=MOE_DP_CHUNK_SIZE,
|
||||||
)
|
)
|
||||||
self.moe_config = moe
|
self.moe_config = moe
|
||||||
@ -822,12 +849,12 @@ class FusedMoE(torch.nn.Module):
|
|||||||
if self.moe_parallel_config.use_pplx_kernels:
|
if self.moe_parallel_config.use_pplx_kernels:
|
||||||
self.batched_hidden_states = torch.zeros(
|
self.batched_hidden_states = torch.zeros(
|
||||||
(MOE_DP_CHUNK_SIZE, self.hidden_size),
|
(MOE_DP_CHUNK_SIZE, self.hidden_size),
|
||||||
dtype=moe.in_dtype,
|
dtype=vllm_config.model_config.dtype,
|
||||||
device=torch.cuda.current_device())
|
device=torch.cuda.current_device())
|
||||||
|
|
||||||
self.batched_router_logits = torch.zeros(
|
self.batched_router_logits = torch.zeros(
|
||||||
(MOE_DP_CHUNK_SIZE, self.global_num_experts),
|
(MOE_DP_CHUNK_SIZE, self.global_num_experts),
|
||||||
dtype=moe.in_dtype,
|
dtype=vllm_config.model_config.dtype,
|
||||||
device=torch.cuda.current_device())
|
device=torch.cuda.current_device())
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user