Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
Bill Nell 2025-05-28 23:37:04 +00:00
parent 03b41b6cad
commit 3ca8322b74
4 changed files with 10 additions and 28 deletions

View File

@ -205,13 +205,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
use_fp8_w8a8 = dtype == torch.torch.float8_e4m3fn use_fp8_w8a8 = dtype == torch.torch.float8_e4m3fn
block_shape = [16, 16, 32] # 16 for k if not fp8 block_shape = [16, 16, 32] # 16 for k if not fp8
#print(f"tensors.A {tensors.A.shape}")
#print(f"tensors.B {tensors.B.shape}")
if use_fp8_w8a8: if use_fp8_w8a8:
#A_scale = torch.ones((1, K), dtype=torch.float32, device=tensors.A.device)
#B_scale = torch.ones((N, K), dtype=torch.float32, device=tensors.A.device)
#quant_block_shape = [N, K]
A_scale = torch.ones(1, dtype=torch.float32, device=tensors.A.device) A_scale = torch.ones(1, dtype=torch.float32, device=tensors.A.device)
B_scale = torch.ones(1, dtype=torch.float32, device=tensors.B.device) B_scale = torch.ones(1, dtype=torch.float32, device=tensors.B.device)
quant_block_shape = [1, 1] quant_block_shape = [1, 1]

View File

@ -63,6 +63,7 @@ requires_pplx = pytest.mark.skipif(
reason="Requires PPLX kernels", reason="Requires PPLX kernels",
) )
@dataclasses.dataclass @dataclasses.dataclass
class ProcessGroupInfo: class ProcessGroupInfo:
world_size: int world_size: int

View File

@ -10,8 +10,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
get_config_dtype_str, try_get_optimal_moe_config) get_config_dtype_str, try_get_optimal_moe_config)
from vllm.model_executor.layers.fused_moe.utils import ( from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, _resize_cache, moe_kernel_quantize_input)
moe_kernel_quantize_input)
@triton.jit @triton.jit
@ -480,8 +479,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
self.qtype, self.qtype,
self.per_act_token, self.per_act_token,
self.block_shape, self.block_shape,
) ))
)
else: else:
b_a1[idx, :rows, :] = rhs b_a1[idx, :rows, :] = rhs
@ -652,10 +650,8 @@ def batched_moe_kernel_quantize_input(
if num_tokens > 0: if num_tokens > 0:
A_q[e, :num_tokens, :], tmp_scale = moe_kernel_quantize_input( A_q[e, :num_tokens, :], tmp_scale = moe_kernel_quantize_input(
A[e, :num_tokens], A[e, :num_tokens],
A_scale[e, :num_tokens] if A_scale else None, A_scale[e, :num_tokens] if A_scale else None, qtype,
qtype, per_channel_quant, [block_k, block_n])
per_channel_quant,
[block_k, block_n])
A_q_scale[e, :tmp_scale.shape[0]] = tmp_scale A_q_scale[e, :tmp_scale.shape[0]] = tmp_scale
return A_q, A_q_scale return A_q, A_q_scale
@ -812,16 +808,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
intermediate_cache1.view(-1, N)) intermediate_cache1.view(-1, N))
qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input( qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input(
intermediate_cache2, intermediate_cache2, a2_scale, num_tokens, E, N, expert_num_tokens,
a2_scale, self.qtype, self.per_act_token, self.block_shape)
num_tokens,
E,
N,
expert_num_tokens,
self.qtype,
self.per_act_token,
self.block_shape
)
invoke_moe_batched_triton_kernel(A=qintermediate_cache2, invoke_moe_batched_triton_kernel(A=qintermediate_cache2,
B=w2, B=w2,

View File

@ -769,13 +769,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
del layer.w2_input_scale del layer.w2_input_scale
def select_gemm_impl(self, prepare_finalize): def select_gemm_impl(self, prepare_finalize):
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize, BatchedPrepareAndFinalize, BatchedTritonExperts)
BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize) PplxPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts)
assert not self.use_marlin and not self.rocm_aiter_moe_enabled, ( assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
"Marlin and ROCm AITER are not supported with all2all yet.") "Marlin and ROCm AITER are not supported with all2all yet.")