mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-12 10:27:04 +08:00
lint
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
parent
03b41b6cad
commit
3ca8322b74
@ -205,13 +205,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
||||
use_fp8_w8a8 = dtype == torch.torch.float8_e4m3fn
|
||||
block_shape = [16, 16, 32] # 16 for k if not fp8
|
||||
|
||||
#print(f"tensors.A {tensors.A.shape}")
|
||||
#print(f"tensors.B {tensors.B.shape}")
|
||||
|
||||
if use_fp8_w8a8:
|
||||
#A_scale = torch.ones((1, K), dtype=torch.float32, device=tensors.A.device)
|
||||
#B_scale = torch.ones((N, K), dtype=torch.float32, device=tensors.A.device)
|
||||
#quant_block_shape = [N, K]
|
||||
A_scale = torch.ones(1, dtype=torch.float32, device=tensors.A.device)
|
||||
B_scale = torch.ones(1, dtype=torch.float32, device=tensors.B.device)
|
||||
quant_block_shape = [1, 1]
|
||||
|
||||
@ -63,6 +63,7 @@ requires_pplx = pytest.mark.skipif(
|
||||
reason="Requires PPLX kernels",
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ProcessGroupInfo:
|
||||
world_size: int
|
||||
|
||||
@ -10,8 +10,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
get_config_dtype_str, try_get_optimal_moe_config)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
_resize_cache,
|
||||
moe_kernel_quantize_input)
|
||||
_resize_cache, moe_kernel_quantize_input)
|
||||
|
||||
|
||||
@triton.jit
|
||||
@ -480,8 +479,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
self.qtype,
|
||||
self.per_act_token,
|
||||
self.block_shape,
|
||||
)
|
||||
)
|
||||
))
|
||||
else:
|
||||
b_a1[idx, :rows, :] = rhs
|
||||
|
||||
@ -652,10 +650,8 @@ def batched_moe_kernel_quantize_input(
|
||||
if num_tokens > 0:
|
||||
A_q[e, :num_tokens, :], tmp_scale = moe_kernel_quantize_input(
|
||||
A[e, :num_tokens],
|
||||
A_scale[e, :num_tokens] if A_scale else None,
|
||||
qtype,
|
||||
per_channel_quant,
|
||||
[block_k, block_n])
|
||||
A_scale[e, :num_tokens] if A_scale else None, qtype,
|
||||
per_channel_quant, [block_k, block_n])
|
||||
A_q_scale[e, :tmp_scale.shape[0]] = tmp_scale
|
||||
|
||||
return A_q, A_q_scale
|
||||
@ -812,16 +808,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
intermediate_cache1.view(-1, N))
|
||||
|
||||
qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input(
|
||||
intermediate_cache2,
|
||||
a2_scale,
|
||||
num_tokens,
|
||||
E,
|
||||
N,
|
||||
expert_num_tokens,
|
||||
self.qtype,
|
||||
self.per_act_token,
|
||||
self.block_shape
|
||||
)
|
||||
intermediate_cache2, a2_scale, num_tokens, E, N, expert_num_tokens,
|
||||
self.qtype, self.per_act_token, self.block_shape)
|
||||
|
||||
invoke_moe_batched_triton_kernel(A=qintermediate_cache2,
|
||||
B=w2,
|
||||
|
||||
@ -769,13 +769,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
del layer.w2_input_scale
|
||||
|
||||
def select_gemm_impl(self, prepare_finalize):
|
||||
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
||||
TritonOrDeepGemmExperts)
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedPrepareAndFinalize,
|
||||
BatchedTritonExperts)
|
||||
BatchedPrepareAndFinalize, BatchedTritonExperts)
|
||||
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
|
||||
PplxPrepareAndFinalize)
|
||||
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
||||
TritonOrDeepGemmExperts)
|
||||
|
||||
assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
|
||||
"Marlin and ROCm AITER are not supported with all2all yet.")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user