Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
Bill Nell 2025-05-28 23:37:04 +00:00
parent 03b41b6cad
commit 3ca8322b74
4 changed files with 10 additions and 28 deletions

View File

@ -205,13 +205,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
use_fp8_w8a8 = dtype == torch.torch.float8_e4m3fn
block_shape = [16, 16, 32] # 16 for k if not fp8
#print(f"tensors.A {tensors.A.shape}")
#print(f"tensors.B {tensors.B.shape}")
if use_fp8_w8a8:
#A_scale = torch.ones((1, K), dtype=torch.float32, device=tensors.A.device)
#B_scale = torch.ones((N, K), dtype=torch.float32, device=tensors.A.device)
#quant_block_shape = [N, K]
A_scale = torch.ones(1, dtype=torch.float32, device=tensors.A.device)
B_scale = torch.ones(1, dtype=torch.float32, device=tensors.B.device)
quant_block_shape = [1, 1]

View File

@ -63,6 +63,7 @@ requires_pplx = pytest.mark.skipif(
reason="Requires PPLX kernels",
)
@dataclasses.dataclass
class ProcessGroupInfo:
world_size: int

View File

@ -10,8 +10,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.fused_moe import (
get_config_dtype_str, try_get_optimal_moe_config)
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache,
moe_kernel_quantize_input)
_resize_cache, moe_kernel_quantize_input)
@triton.jit
@ -480,8 +479,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
self.qtype,
self.per_act_token,
self.block_shape,
)
)
))
else:
b_a1[idx, :rows, :] = rhs
@ -652,10 +650,8 @@ def batched_moe_kernel_quantize_input(
if num_tokens > 0:
A_q[e, :num_tokens, :], tmp_scale = moe_kernel_quantize_input(
A[e, :num_tokens],
A_scale[e, :num_tokens] if A_scale else None,
qtype,
per_channel_quant,
[block_k, block_n])
A_scale[e, :num_tokens] if A_scale else None, qtype,
per_channel_quant, [block_k, block_n])
A_q_scale[e, :tmp_scale.shape[0]] = tmp_scale
return A_q, A_q_scale
@ -812,16 +808,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
intermediate_cache1.view(-1, N))
qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input(
intermediate_cache2,
a2_scale,
num_tokens,
E,
N,
expert_num_tokens,
self.qtype,
self.per_act_token,
self.block_shape
)
intermediate_cache2, a2_scale, num_tokens, E, N, expert_num_tokens,
self.qtype, self.per_act_token, self.block_shape)
invoke_moe_batched_triton_kernel(A=qintermediate_cache2,
B=w2,

View File

@ -769,13 +769,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
del layer.w2_input_scale
def select_gemm_impl(self, prepare_finalize):
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize,
BatchedTritonExperts)
BatchedPrepareAndFinalize, BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts)
assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
"Marlin and ROCm AITER are not supported with all2all yet.")