mirror of https://git.datalinker.icu/vllm-project/vllm.git, synced 2026-04-28 23:47:10 +08:00
cleanup quantization
Signed-off-by: Bill Nell <bnell@redhat.com>
parent 909f234faa
commit 468d16654a
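The diff below repeatedly converts between the old use_fp8_w8a8 boolean and the new optional quantization dtype (qtype). A minimal sketch of that mapping, with helper names invented here purely for illustration:

from typing import Optional

import torch


def qtype_from_flag(use_fp8_w8a8: bool) -> Optional[torch.dtype]:
    # New convention in this diff: None means "do not quantize activations",
    # torch.float8_e4m3fn selects the fp8 w8a8 path.
    return torch.float8_e4m3fn if use_fp8_w8a8 else None


def flag_from_qtype(qtype: Optional[torch.dtype]) -> bool:
    # Inverse mapping, used where kernels still expect the boolean flag.
    return qtype == torch.float8_e4m3fn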
@@ -270,8 +270,9 @@ def batched_moe(
     topk_ids: torch.Tensor,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
-    use_fp8_w8a8: bool = False,
+    qtype: Optional[torch.dtype] = None,
     block_shape: Optional[list[int]] = None,
+    per_act_token: bool = False,
 ) -> torch.Tensor:
     max_num_tokens = round_up(a.shape[0], 64)  # ?
     fused_experts = FusedMoEModularKernel(
@@ -279,12 +280,13 @@ def batched_moe(
                                   world_size=1,
                                   dp_size=1,
                                   rank=0,
-                                  use_fp8_w8a8=use_fp8_w8a8,
-                                  block_shape=block_shape),
+                                  qtype=qtype,
+                                  block_shape=block_shape,
+                                  per_act_token=False),
         BatchedTritonExperts(max_num_tokens=max_num_tokens,
                              dp_size=1,
                              world_size=1,
-                             use_fp8_w8a8=use_fp8_w8a8,
+                             use_fp8_w8a8=qtype == torch.float8_e4m3fn,
                              block_shape=block_shape))

     return fused_experts(a,
@@ -360,7 +362,7 @@ def torch_moe2(
 @pytest.mark.parametrize("k", [128, 512, 1024])
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
-@pytest.mark.parametrize("dtype", [torch.torch.float8_e4m3fn, torch.bfloat16])
+@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
 def test_fused_moe_batched_experts(
     m: int,
     n: int,
@@ -378,6 +380,7 @@ def test_fused_moe_batched_experts(
     score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)

     use_fp8_w8a8 = dtype == torch.torch.float8_e4m3fn
+    qtype = dtype if dtype == torch.torch.float8_e4m3fn else None

     if use_fp8_w8a8:
         block_n, block_k = block_shape[0], block_shape[1]
@@ -409,7 +412,7 @@ def test_fused_moe_batched_experts(
     baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids, w1_s,
                                  w2_s, use_fp8_w8a8, block_shape)
     batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
-                                 w2_s, use_fp8_w8a8, block_shape)
+                                 w2_s, qtype, block_shape)

     torch.testing.assert_close(baseline_output,
                                batched_output,
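For reference, the relationship the updated test now relies on between the parametrized dtype, the use_fp8_w8a8 flag (still passed to the torch_moe2 reference), and the qtype passed to batched_moe; a small self-contained sketch:

import torch

# dtype is parametrized over [torch.float8_e4m3fn, torch.bfloat16] in the test.
for dtype in (torch.float8_e4m3fn, torch.bfloat16):
    # The reference path (torch_moe2) still takes the boolean flag...
    use_fp8_w8a8 = dtype == torch.float8_e4m3fn
    # ...while the batched path (batched_moe) now takes an Optional[torch.dtype].
    qtype = dtype if dtype == torch.float8_e4m3fn else None
    assert use_fp8_w8a8 == (qtype is not None)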
@@ -9,9 +9,9 @@ import triton.language as tl
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     get_config_dtype_str, try_get_optimal_moe_config)
-from vllm.model_executor.layers.fused_moe.utils import _resize_cache
-from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    per_token_group_quant_fp8)
+from vllm.model_executor.layers.fused_moe.utils import (
+    _resize_cache,
+    moe_kernel_quantize_input)


 @triton.jit
@@ -47,6 +47,7 @@ def moe_mmk(
         compute_type: tl.constexpr,
         use_w8a8: tl.constexpr,
         use_w8a16: tl.constexpr):

     offs_k = tl.arange(0, BLOCK_K)

     if use_w8a16:
@@ -325,6 +326,7 @@ def invoke_moe_batched_triton_kernel(
         use_int4_w4a16: bool,
         config: dict[str, int],
         block_shape: Optional[list[int]] = None):

     assert not use_int4_w4a16
     max_num_tokens = A.size(1)
     K = A.size(2)
@@ -393,15 +395,17 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
                  world_size: int,
                  dp_size: int,
                  rank: int,
-                 use_fp8_w8a8: bool = False,
+                 qtype: Optional[torch.dtype] = None,
+                 per_act_token: bool = False,
                  block_shape: Optional[list[int]] = None):
         super().__init__()
         self.world_size = world_size
         self.dp_size = dp_size
         self.rank = rank
         self.max_num_tokens = max_num_tokens
-        self.use_fp8_w8a8 = use_fp8_w8a8
+        self.per_act_token = per_act_token
         self.block_shape = block_shape
+        self.qtype = qtype

     def prepare(
         self,
@@ -445,10 +449,10 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):

         b_a1 = torch.zeros(
             (num_local_experts, self.max_num_tokens, hidden_dim),
-            dtype=torch.float8_e4m3fn if self.use_fp8_w8a8 else a1.dtype,
+            dtype=self.qtype if self.qtype is not None else a1.dtype,
             device=a1.device)

-        if self.use_fp8_w8a8:
+        if self.qtype is not None:
             k_tiles = (hidden_dim + block_k - 1) // block_k
             b_a1_scale = torch.zeros(
                 (num_local_experts, self.max_num_tokens, k_tiles),
@@ -465,10 +469,20 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             rows = torch.count_nonzero(topks.flatten())
             rhs = a1[:topks.numel()][topks]
             idx = expert_id - first_expert
-            if self.use_fp8_w8a8:
-                # TODO: use _fp8_quantize
-                b_a1[idx, :rows, :], b_a1_scale[
-                    idx, :rows] = per_token_group_quant_fp8(rhs, block_k)
+            if self.qtype is not None:
+                if a1_scale is not None:
+                    rhs_a1_scale = a1_scale[:topks.numel()][topks]
+                else:
+                    rhs_a1_scale = None
+                b_a1[idx, :rows, :], b_a1_scale[idx, :rows] = (
+                    moe_kernel_quantize_input(
+                        rhs,
+                        rhs_a1_scale,
+                        self.qtype,
+                        self.per_act_token,
+                        self.block_shape,
+                    )
+                )
             else:
                 b_a1[idx, :rows, :] = rhs

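prepare() now delegates activation quantization to moe_kernel_quantize_input rather than calling per_token_group_quant_fp8 directly. A rough, self-contained sketch of the per-token-group fp8 behaviour that call implies when a block_shape is given; this illustrates the math under stated assumptions (k divisible by block_k, helper names invented here), not vLLM's actual implementation:

from typing import Optional

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max


def group_quant_fp8_sketch(x: torch.Tensor,
                           block_k: int) -> tuple[torch.Tensor, torch.Tensor]:
    # One scale per (token, K-tile): split the last dim into groups of
    # block_k, take the per-group absmax, and rescale into the fp8 range.
    m, k = x.shape
    assert k % block_k == 0
    tiles = x.view(m, k // block_k, block_k).float()
    scale = tiles.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / FP8_MAX
    x_q = (tiles / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return x_q.view(m, k), scale.squeeze(-1)  # scales: (m, k // block_k)


def quantize_if_needed(x: torch.Tensor,
                       x_scale: Optional[torch.Tensor],
                       qtype: Optional[torch.dtype],
                       block_shape: Optional[list[int]]):
    # qtype=None mirrors the unquantized branch: activations pass through.
    if qtype is None:
        return x, x_scale
    assert qtype == torch.float8_e4m3fn and block_shape is not None
    # block_shape is [block_n, block_k]; activations are grouped along K.
    return group_quant_fp8_sketch(x, block_shape[1])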
@@ -524,7 +538,6 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
         block_m: Optional[int] = None,
     ):
         super().__init__()
-        #assert block_shape is None
         assert block_m is None
         assert not use_int8_w8a8, "NYI"
         assert not use_int8_w8a16, "NYI"
@@ -615,6 +628,42 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
         return out


+def batched_moe_kernel_quantize_input(
+    A: torch.Tensor,
+    A_scale: Optional[torch.Tensor],
+    num_tokens: int,
+    E: int,
+    N: int,
+    expert_num_tokens: torch.Tensor,
+    qtype: Optional[torch.dtype],
+    per_channel_quant: bool,
+    block_shape: Optional[list[int]] = None,
+) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+    if qtype is not None:
+        assert block_shape is not None
+        A_q = torch.empty_like(A, dtype=qtype)
+        block_n, block_k = block_shape
+        n_tiles = ((N // 2) + block_n - 1) // block_n
+        scale_shape = (E, num_tokens, n_tiles)
+        A_q_scale = torch.empty(scale_shape,
+                                dtype=torch.float32,
+                                device=A.device)
+        for e in range(E):
+            num_tokens = expert_num_tokens[e]
+            if num_tokens > 0:
+                A_q[e, :num_tokens, :], tmp_scale = moe_kernel_quantize_input(
+                    A[e, :num_tokens],
+                    A_scale[e, :num_tokens] if A_scale else None,
+                    qtype,
+                    per_channel_quant,
+                    [block_k, block_n])
+                A_q_scale[e, :tmp_scale.shape[0]] = tmp_scale
+
+        return A_q, A_q_scale
+    else:
+        return A, A_scale
+
+
 class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
     """
     A Triton based MoE expert class that operates on expert batched format,
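A small usage sketch for the helper added above, exercising the qtype=None pass-through; the tensors and sizes are made up for illustration, and batched_moe_kernel_quantize_input is assumed to be in scope (it is defined in the hunk above):

import torch

E, max_num_tokens, N = 4, 32, 512
# Expert-batched activations after the first GEMM/activation: (E, max_tokens, N // 2).
intermediate = torch.randn(E, max_num_tokens, N // 2, dtype=torch.bfloat16)
expert_num_tokens = torch.tensor([32, 5, 0, 12])

# With qtype=None the helper is a no-op and hands back its inputs unchanged.
a_q, a_q_scale = batched_moe_kernel_quantize_input(
    intermediate, None, max_num_tokens, E, N, expert_num_tokens,
    qtype=None, per_channel_quant=False, block_shape=None)
assert a_q is intermediate and a_q_scale is None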
@@ -630,6 +679,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
                  use_int8_w8a16: bool = False,
                  use_int4_w4a16: bool = False,
                  block_shape: Optional[list[int]] = None,
+                 per_act_token: bool = False,
                  world_size: int = 1,
                  dp_size: int = 1,
     ):
@@ -644,6 +694,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
         assert not use_int4_w4a16, "NYI"
         self.world_size = world_size
         self.dp_size = dp_size
+        self.per_act_token = per_act_token
+        self.qtype = torch.float8_e4m3fn if self.use_fp8_w8a8 else None

     def workspace_shapes(
         self,
@@ -731,7 +783,6 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
             raise ValueError(
                 f"Unsupported compute_type: {hidden_states.dtype}")

-        #print(f"shape: E={E}, M={num_tokens}, N={N}, K={K}, top_k={top_k_num}")
         # We can reuse the memory between these because by the time we need
         # cache3, we're done with cache1
         intermediate_cache1 = _resize_cache(workspace13, (E, num_tokens, N))
@@ -761,36 +812,17 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
         self.activation(activation, intermediate_cache2.view(-1, N // 2),
                         intermediate_cache1.view(-1, N))

-        #qintermediate_cache2 = intermediate_cache2
-
-        # TODO (varun) : support w8a8
-        #assert not self.use_fp8_w8a8
-        if self.use_fp8_w8a8:
-            per_act_token = False
-            # TODO: reuse?
-            qintermediate_cache2 = torch.empty_like(intermediate_cache2,
-                                                    dtype=torch.float8_e4m3fn)
-            block_n = self.block_shape[0]
-            n_tiles = ((N // 2) + block_n - 1) // block_n
-            scale_shape = (E, num_tokens, n_tiles)
-            a2q_scale = torch.empty(scale_shape,
-                                    dtype=torch.float32,
-                                    device=hidden_states.device)
-            for e in range(E):
-                num_tokens = expert_num_tokens[e]
-                if num_tokens > 0:
-                    #qintermediate_cache2[e], tmp_scale = _fp8_quantize(
-                    #    intermediate_cache2[e],
-                    #    a2_scale[e] if a2_scale is not None else None,
-                    #    per_act_token, self.block_shape)
-                    qintermediate_cache2[
-                        e, :
-                        num_tokens, :], tmp_scale = per_token_group_quant_fp8(
-                            intermediate_cache2[e, :num_tokens], block_n)
-                    a2q_scale[e, :tmp_scale.shape[0]] = tmp_scale
-        else:
-            qintermediate_cache2 = intermediate_cache2
-            a2q_scale = a2_scale
+        qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input(
+            intermediate_cache2,
+            a2_scale,
+            num_tokens,
+            E,
+            N,
+            expert_num_tokens,
+            self.qtype,
+            self.per_act_token,
+            self.block_shape
+        )

         invoke_moe_batched_triton_kernel(A=qintermediate_cache2,
                                          B=w2,
@@ -1520,11 +1520,11 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):

     def __init__(
         self,
-        use_fp8_w8a8: bool,
-        use_int8_w8a8: bool,
-        use_int8_w8a16: bool,
-        use_int4_w4a16: bool,
-        per_channel_quant: bool,
+        use_fp8_w8a8: bool = False,
+        use_int8_w8a8: bool = False,
+        use_int8_w8a16: bool = False,
+        use_int4_w4a16: bool = False,
+        per_channel_quant: bool = False,
         block_shape: Optional[list[int]] = None,
         block_m: Optional[int] = None,
     ):
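With every quantization argument now defaulted, the bare TritonExperts() call in the layer hunk further down is equivalent to spelling out each flag; a sketch under the assumption that TritonExperts is importable from vllm.model_executor.layers.fused_moe.fused_moe (the module path is inferred, not stated in this diff):

from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts

# Both constructions are equivalent once the defaults above are in place.
experts_a = TritonExperts()
experts_b = TritonExperts(
    use_fp8_w8a8=False,
    use_int8_w8a8=False,
    use_int8_w8a16=False,
    use_int4_w4a16=False,
    per_channel_quant=False,
    block_shape=None,
    block_m=None,
)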
@@ -192,7 +192,7 @@ class MoEConfig:
     num_local_experts: int
     moe_parallel_config: FusedMoEParallelConfig

-    in_dtype: torch.dtype  # The activation type.
+    in_dtype: torch.dtype  # The post quantization activation type.

     # TODO: add more quantization params, blocked, per-token, etc.
     block_size: int = 128
@@ -489,22 +489,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 max_num_tokens=MOE_DP_CHUNK_SIZE,
                 world_size=world_size,
                 dp_size=dp_size,
-                use_fp8_w8a8=False,  #moe.in_dtype == torch.float8_e4m3fn,
-                use_int8_w8a8=False,
-                use_int8_w8a16=False,
-                use_int4_w4a16=False,
-                block_shape=None,
             )
         else:
             logger.debug("TritonExperts %s", self.moe)
-            experts = TritonExperts(
-                use_fp8_w8a8=False,
-                use_int8_w8a8=False,
-                use_int8_w8a16=False,
-                use_int4_w4a16=False,
-                block_shape=None,
-                per_channel_quant=False,
-            )
+            experts = TritonExperts()

         self.fused_experts = FusedMoEModularKernel(
             prepare_finalize,
@@ -827,8 +815,7 @@ class FusedMoE(torch.nn.Module):
             from vllm_hpu_extension.ops import DynamicFusedMOE
             self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)

-        logger.debug(f"PARAM DTYPE = {params_dtype}")
-        #assert params_dtype.itemsize == 1
+        logger.debug("Model dtype = %s", vllm_config.model_config.dtype)

         moe = MoEConfig(
             max_num_tokens=MOE_DP_CHUNK_SIZE,
@@ -838,7 +825,6 @@ class FusedMoE(torch.nn.Module):
             num_local_experts=self.local_num_experts,
             moe_parallel_config=self.moe_parallel_config,
             in_dtype=moe.in_dtype,
-            max_num_tokens=MOE_DP_CHUNK_SIZE,
         )
         self.moe_config = moe
         self.quant_config = quant_config
@@ -877,15 +863,14 @@ class FusedMoE(torch.nn.Module):
         self.batched_hidden_states: Optional[torch.Tensor] = None
         self.batched_router_logits: Optional[torch.Tensor] = None
         if self.moe_parallel_config.use_pplx_kernels:
-            act_dtype = vllm_config.model_config.dtype
             self.batched_hidden_states = torch.zeros(
                 (MOE_DP_CHUNK_SIZE, self.hidden_size),
-                dtype=act_dtype,
+                dtype=moe.in_dtype,
                 device=torch.cuda.current_device())

             self.batched_router_logits = torch.zeros(
                 (MOE_DP_CHUNK_SIZE, self.global_num_experts),
-                dtype=act_dtype,
+                dtype=moe.in_dtype,
                 device=torch.cuda.current_device())

     @property
@@ -782,11 +782,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 max_num_tokens=MOE_DP_CHUNK_SIZE,
                 world_size=world_size,
                 dp_size=dp_size,
-                use_fp8_w8a8=True,
-                use_int8_w8a8=False,
-                use_int8_w8a16=False,
-                use_int4_w4a16=False,
+                qtype=torch.float8_e4m3fn,
                 block_shape=self.quant_config.weight_block_size,
+                per_act_token=False,  #?
             )
         else:
             logger.debug("TritonOrDeepGemmExperts(fp8)")