mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-29 04:07:03 +08:00
[Kernels][Bugfix] Use torch op for all kernels in FusedMoE forward. Add additional testing for cudagraphs. (#19717)
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
parent
f59fc60fb3
commit
015fab8c2f
@ -29,7 +29,10 @@ MNK_FACTORS = [
|
|||||||
(224, 1024, 1536),
|
(224, 1024, 1536),
|
||||||
(224, 3072, 1024),
|
(224, 3072, 1024),
|
||||||
(224, 3072, 1536),
|
(224, 3072, 1536),
|
||||||
(1024 * 128, 1024, 1024),
|
(32768, 1024, 1024),
|
||||||
|
# These sizes trigger wrong answers.
|
||||||
|
#(7232, 2048, 5120),
|
||||||
|
#(40000, 2048, 5120),
|
||||||
]
|
]
|
||||||
|
|
||||||
vllm_config = VllmConfig(parallel_config=ParallelConfig(
|
vllm_config = VllmConfig(parallel_config=ParallelConfig(
|
||||||
@ -232,8 +235,10 @@ def test_cutlass_moe_8_bit_no_graph(
|
|||||||
topk: int,
|
topk: int,
|
||||||
per_act_token: bool,
|
per_act_token: bool,
|
||||||
per_out_ch: bool,
|
per_out_ch: bool,
|
||||||
|
monkeypatch,
|
||||||
):
|
):
|
||||||
current_platform.seed_everything(7)
|
current_platform.seed_everything(7)
|
||||||
|
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
|
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
|
||||||
per_out_ch)
|
per_out_ch)
|
||||||
@ -274,8 +279,10 @@ def test_cutlass_moe_8_bit_cuda_graph(
|
|||||||
topk: int,
|
topk: int,
|
||||||
per_act_token: bool,
|
per_act_token: bool,
|
||||||
per_out_ch: bool,
|
per_out_ch: bool,
|
||||||
|
monkeypatch,
|
||||||
):
|
):
|
||||||
current_platform.seed_everything(7)
|
current_platform.seed_everything(7)
|
||||||
|
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
dtype = torch.half
|
dtype = torch.half
|
||||||
|
|
||||||
@ -329,8 +336,10 @@ def test_cutlass_moe_8_bit_EP(
|
|||||||
per_act_token: bool,
|
per_act_token: bool,
|
||||||
per_out_channel: bool,
|
per_out_channel: bool,
|
||||||
ep_size: int,
|
ep_size: int,
|
||||||
|
monkeypatch,
|
||||||
):
|
):
|
||||||
current_platform.seed_everything(7)
|
current_platform.seed_everything(7)
|
||||||
|
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
|
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
|
||||||
per_out_channel)
|
per_out_channel)
|
||||||
|
|||||||
@ -4,6 +4,9 @@
|
|||||||
|
|
||||||
Run `pytest tests/kernels/test_moe.py`.
|
Run `pytest tests/kernels/test_moe.py`.
|
||||||
"""
|
"""
|
||||||
|
import functools
|
||||||
|
from typing import Callable, Optional, Union
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from torch.nn import Parameter
|
from torch.nn import Parameter
|
||||||
@ -14,6 +17,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
|||||||
import vllm.model_executor.layers.fused_moe # noqa
|
import vllm.model_executor.layers.fused_moe # noqa
|
||||||
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
|
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
|
||||||
from vllm.config import VllmConfig, set_current_vllm_config
|
from vllm.config import VllmConfig, set_current_vllm_config
|
||||||
|
from vllm.forward_context import set_forward_context
|
||||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||||
fused_topk, modular_triton_fused_moe)
|
fused_topk, modular_triton_fused_moe)
|
||||||
@ -40,7 +44,76 @@ vllm_config.scheduler_config.max_num_seqs = 128
|
|||||||
vllm_config.scheduler_config.max_model_len = 8192
|
vllm_config.scheduler_config.max_model_len = 8192
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128])
|
def run_moe_test(
|
||||||
|
baseline: Union[Callable, torch.Tensor],
|
||||||
|
moe_fn: Callable,
|
||||||
|
a: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
score: torch.Tensor,
|
||||||
|
topk: int,
|
||||||
|
global_num_experts: int = -1,
|
||||||
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
|
padding: bool = False,
|
||||||
|
use_compile: bool = False,
|
||||||
|
use_cudagraph: bool = False,
|
||||||
|
atol: float = 2e-2,
|
||||||
|
rtol: float = 0,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if isinstance(baseline, torch.Tensor):
|
||||||
|
baseline_output = baseline
|
||||||
|
else:
|
||||||
|
baseline_output = baseline(a,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
score,
|
||||||
|
topk,
|
||||||
|
global_num_experts=global_num_experts,
|
||||||
|
expert_map=expert_map)
|
||||||
|
|
||||||
|
# Pad the weight if moe padding is enabled
|
||||||
|
if padding:
|
||||||
|
w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128]
|
||||||
|
w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128]
|
||||||
|
|
||||||
|
if use_compile:
|
||||||
|
moe_fn = torch.compile(moe_fn, backend="inductor", fullgraph=True)
|
||||||
|
torch._dynamo.mark_dynamic(a, 0)
|
||||||
|
torch._dynamo.mark_dynamic(score, 0)
|
||||||
|
|
||||||
|
test_output = moe_fn(a,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
score,
|
||||||
|
topk,
|
||||||
|
global_num_experts=global_num_experts,
|
||||||
|
expert_map=expert_map)
|
||||||
|
|
||||||
|
if use_cudagraph:
|
||||||
|
test_output.fill_(0)
|
||||||
|
stream = torch.cuda.Stream()
|
||||||
|
graph = torch.cuda.CUDAGraph()
|
||||||
|
with torch.cuda.graph(graph, stream=stream):
|
||||||
|
test_output = moe_fn(a,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
score,
|
||||||
|
topk,
|
||||||
|
global_num_experts=global_num_experts,
|
||||||
|
expert_map=expert_map)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
graph.replay()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
torch.testing.assert_close(test_output,
|
||||||
|
baseline_output,
|
||||||
|
atol=atol,
|
||||||
|
rtol=rtol)
|
||||||
|
|
||||||
|
return baseline_output
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("m", [1, 33, 64, 222, 32768, 40000])
|
||||||
@pytest.mark.parametrize("n", [128, 1024, 2048])
|
@pytest.mark.parametrize("n", [128, 1024, 2048])
|
||||||
@pytest.mark.parametrize("k", [128, 511, 1024])
|
@pytest.mark.parametrize("k", [128, 511, 1024])
|
||||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||||
@ -48,6 +121,7 @@ vllm_config.scheduler_config.max_model_len = 8192
|
|||||||
@pytest.mark.parametrize("ep_size", EP_SIZE)
|
@pytest.mark.parametrize("ep_size", EP_SIZE)
|
||||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||||
@pytest.mark.parametrize("padding", [True, False])
|
@pytest.mark.parametrize("padding", [True, False])
|
||||||
|
@pytest.mark.parametrize("chunk_size", [8192])
|
||||||
def test_fused_moe(
|
def test_fused_moe(
|
||||||
m: int,
|
m: int,
|
||||||
n: int,
|
n: int,
|
||||||
@ -57,7 +131,17 @@ def test_fused_moe(
|
|||||||
ep_size: int,
|
ep_size: int,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
padding: bool,
|
padding: bool,
|
||||||
|
chunk_size: int,
|
||||||
|
monkeypatch,
|
||||||
):
|
):
|
||||||
|
current_platform.seed_everything(7)
|
||||||
|
|
||||||
|
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
|
||||||
|
|
||||||
|
#
|
||||||
|
# Setup test data
|
||||||
|
#
|
||||||
|
|
||||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||||
@ -77,58 +161,70 @@ def test_fused_moe(
|
|||||||
else:
|
else:
|
||||||
e_map = None
|
e_map = None
|
||||||
|
|
||||||
m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=False,
|
#
|
||||||
use_int8_w8a8=False,
|
# Setup test functions
|
||||||
use_int8_w8a16=False,
|
#
|
||||||
use_int4_w4a16=False,
|
|
||||||
per_channel_quant=False,
|
m_fused_moe_fn = modular_triton_fused_moe(use_fp8_w8a8=False,
|
||||||
block_shape=None)
|
use_int8_w8a8=False,
|
||||||
|
use_int8_w8a16=False,
|
||||||
|
use_int4_w4a16=False,
|
||||||
|
per_channel_quant=False,
|
||||||
|
block_shape=None)
|
||||||
|
|
||||||
|
def m_fused_moe(
|
||||||
|
a: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
score: torch.Tensor,
|
||||||
|
topk: int,
|
||||||
|
global_num_experts: int = -1,
|
||||||
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||||
|
return m_fused_moe_fn(a,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
global_num_experts=global_num_experts,
|
||||||
|
expert_map=expert_map)
|
||||||
|
|
||||||
|
fused_moe_fn = functools.partial(fused_moe, renormalize=False)
|
||||||
|
|
||||||
|
#
|
||||||
|
# Run tests
|
||||||
|
#
|
||||||
|
runner = functools.partial(
|
||||||
|
run_moe_test,
|
||||||
|
a=a,
|
||||||
|
w1=w1,
|
||||||
|
w2=w2,
|
||||||
|
score=score,
|
||||||
|
topk=topk,
|
||||||
|
global_num_experts=e,
|
||||||
|
expert_map=e_map,
|
||||||
|
padding=padding,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Note: for now use_compile will error out if the problem size is
|
||||||
|
# large enough to trigger chunking. I'm leaving the flag and
|
||||||
|
# setup code in case we are able to revisit this later.
|
||||||
|
use_compile = False
|
||||||
|
|
||||||
|
use_cudagraph = (n >= 1024 and k >= 1024
|
||||||
|
and current_platform.is_cuda_alike())
|
||||||
|
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
torch_output = torch_moe(a, w1, w2, score, topk, e_map)
|
baseline_output = runner(torch_moe, iterative_moe)
|
||||||
iterative_output = iterative_moe(a,
|
runner(baseline_output,
|
||||||
w1,
|
fused_moe_fn,
|
||||||
w2,
|
use_compile=use_compile,
|
||||||
score,
|
use_cudagraph=use_cudagraph)
|
||||||
topk,
|
runner(baseline_output,
|
||||||
global_num_experts=e,
|
m_fused_moe,
|
||||||
expert_map=e_map,
|
use_compile=use_compile,
|
||||||
renormalize=False)
|
use_cudagraph=use_cudagraph)
|
||||||
|
|
||||||
# Pad the weight if moe padding is enabled
|
|
||||||
if padding:
|
|
||||||
w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128]
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128]
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
triton_output = fused_moe(a,
|
|
||||||
w1,
|
|
||||||
w2,
|
|
||||||
score,
|
|
||||||
topk,
|
|
||||||
global_num_experts=e,
|
|
||||||
expert_map=e_map,
|
|
||||||
renormalize=False)
|
|
||||||
|
|
||||||
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
|
|
||||||
m_triton_output = m_fused_moe(a,
|
|
||||||
w1,
|
|
||||||
w2,
|
|
||||||
topk_weights,
|
|
||||||
topk_ids,
|
|
||||||
global_num_experts=e,
|
|
||||||
expert_map=e_map)
|
|
||||||
|
|
||||||
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
|
|
||||||
torch.testing.assert_close(m_triton_output,
|
|
||||||
torch_output,
|
|
||||||
atol=2e-2,
|
|
||||||
rtol=0)
|
|
||||||
torch.testing.assert_close(iterative_output,
|
|
||||||
torch_output,
|
|
||||||
atol=2e-2,
|
|
||||||
rtol=0)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("m", [1, 32, 222])
|
@pytest.mark.parametrize("m", [1, 32, 222])
|
||||||
@ -238,7 +334,12 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
|
|||||||
w1_zp=w1_qzeros if has_zp else None,
|
w1_zp=w1_qzeros if has_zp else None,
|
||||||
w2_zp=w2_qzeros if has_zp else None,
|
w2_zp=w2_qzeros if has_zp else None,
|
||||||
block_shape=[0, group_size])
|
block_shape=[0, group_size])
|
||||||
torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map)
|
torch_output = torch_moe(a,
|
||||||
|
w1_ref,
|
||||||
|
w2_ref,
|
||||||
|
score,
|
||||||
|
topk,
|
||||||
|
expert_map=e_map)
|
||||||
|
|
||||||
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
|
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
|
||||||
|
|
||||||
@ -265,45 +366,51 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
|
|||||||
pytest.skip("AITER ROCm test skip for float32")
|
pytest.skip("AITER ROCm test skip for float32")
|
||||||
|
|
||||||
# Instantiate our and huggingface's MoE blocks
|
# Instantiate our and huggingface's MoE blocks
|
||||||
config = MixtralConfig()
|
vllm_config.compilation_config.static_forward_context = dict()
|
||||||
hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
|
with (set_current_vllm_config(vllm_config),
|
||||||
vllm_moe = MixtralMoE(
|
set_forward_context(None, vllm_config)):
|
||||||
num_experts=config.num_local_experts,
|
config = MixtralConfig()
|
||||||
top_k=config.num_experts_per_tok,
|
hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
|
||||||
hidden_size=config.hidden_size,
|
vllm_moe = MixtralMoE(
|
||||||
intermediate_size=config.intermediate_size,
|
num_experts=config.num_local_experts,
|
||||||
params_dtype=dtype,
|
top_k=config.num_experts_per_tok,
|
||||||
tp_size=1,
|
hidden_size=config.hidden_size,
|
||||||
dp_size=1,
|
intermediate_size=config.intermediate_size,
|
||||||
).cuda()
|
params_dtype=dtype,
|
||||||
|
tp_size=1,
|
||||||
|
dp_size=1,
|
||||||
|
).cuda()
|
||||||
|
|
||||||
# Load the weights
|
# Load the weights
|
||||||
vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
|
vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
|
||||||
for i in range(config.num_local_experts):
|
for i in range(config.num_local_experts):
|
||||||
weights = (hf_moe.experts[i].w1.weight.data,
|
weights = (hf_moe.experts[i].w1.weight.data,
|
||||||
hf_moe.experts[i].w3.weight.data)
|
hf_moe.experts[i].w3.weight.data)
|
||||||
vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0)
|
vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0)
|
||||||
vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data
|
vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data
|
||||||
|
|
||||||
# Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
|
# Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
|
||||||
hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
|
hf_inputs = torch.randn(
|
||||||
# vLLM uses 1D query [num_tokens, hidden_dim]
|
(1, 64, config.hidden_size)).to(dtype).to("cuda")
|
||||||
vllm_inputs = hf_inputs.flatten(0, 1)
|
# vLLM uses 1D query [num_tokens, hidden_dim]
|
||||||
|
vllm_inputs = hf_inputs.flatten(0, 1)
|
||||||
|
|
||||||
# Pad the weight if moe padding is enabled
|
# Pad the weight if moe padding is enabled
|
||||||
if padding:
|
if padding:
|
||||||
vllm_moe.experts.w13_weight = Parameter(F.pad(
|
vllm_moe.experts.w13_weight = Parameter(F.pad(
|
||||||
vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[..., 0:-128],
|
vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[...,
|
||||||
requires_grad=False)
|
0:-128],
|
||||||
torch.cuda.empty_cache()
|
requires_grad=False)
|
||||||
vllm_moe.experts.w2_weight = Parameter(F.pad(
|
torch.cuda.empty_cache()
|
||||||
vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., 0:-128],
|
vllm_moe.experts.w2_weight = Parameter(F.pad(
|
||||||
requires_grad=False)
|
vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[...,
|
||||||
torch.cuda.empty_cache()
|
0:-128],
|
||||||
|
requires_grad=False)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
# Run forward passes for both MoE blocks
|
# Run forward passes for both MoE blocks
|
||||||
hf_states, _ = hf_moe.forward(hf_inputs)
|
hf_states, _ = hf_moe.forward(hf_inputs)
|
||||||
vllm_states = vllm_moe.forward(vllm_inputs)
|
vllm_states = vllm_moe.forward(vllm_inputs)
|
||||||
|
|
||||||
mixtral_moe_tol = {
|
mixtral_moe_tol = {
|
||||||
torch.float32: 1e-3,
|
torch.float32: 1e-3,
|
||||||
@ -546,7 +653,12 @@ def test_fused_marlin_moe(
|
|||||||
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
|
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||||
|
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map)
|
torch_output = torch_moe(a,
|
||||||
|
w_ref1,
|
||||||
|
w_ref2,
|
||||||
|
score,
|
||||||
|
topk,
|
||||||
|
expert_map=e_map)
|
||||||
|
|
||||||
marlin_output = torch.ops.vllm.fused_marlin_moe(
|
marlin_output = torch.ops.vllm.fused_marlin_moe(
|
||||||
a,
|
a,
|
||||||
|
|||||||
@ -136,7 +136,7 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
|
|||||||
device=w2.device,
|
device=w2.device,
|
||||||
block_size=quant_blocksize)
|
block_size=quant_blocksize)
|
||||||
|
|
||||||
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk, None)
|
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)
|
||||||
|
|
||||||
torch.testing.assert_close(torch_output,
|
torch.testing.assert_close(torch_output,
|
||||||
cutlass_output,
|
cutlass_output,
|
||||||
|
|||||||
@ -6,9 +6,9 @@ from typing import Optional
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from tests.kernels.utils import torch_experts
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.config import VllmConfig, set_current_vllm_config
|
from vllm.config import VllmConfig, set_current_vllm_config
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
|
||||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
|
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||||
@ -164,22 +164,6 @@ vllm_config.scheduler_config.max_num_seqs = 128
|
|||||||
vllm_config.scheduler_config.max_model_len = 8192
|
vllm_config.scheduler_config.max_model_len = 8192
|
||||||
|
|
||||||
|
|
||||||
def torch_moe2(a, w1, w2, topk_weight, topk_ids):
|
|
||||||
M, K = a.shape
|
|
||||||
topk = topk_ids.shape[1]
|
|
||||||
a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
|
|
||||||
out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
|
||||||
num_experts = w1.shape[0]
|
|
||||||
for i in range(num_experts):
|
|
||||||
mask = (topk_ids == i).view(-1)
|
|
||||||
if mask.sum():
|
|
||||||
out[mask] = SiluAndMul()(
|
|
||||||
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
|
|
||||||
|
|
||||||
return (out.view(M, -1, w2.shape[1]) *
|
|
||||||
topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
|
|
||||||
|
|
||||||
|
|
||||||
def _pplx_moe(
|
def _pplx_moe(
|
||||||
pgi: ProcessGroupInfo,
|
pgi: ProcessGroupInfo,
|
||||||
dp_size: int,
|
dp_size: int,
|
||||||
@ -210,8 +194,8 @@ def _pplx_moe(
|
|||||||
group_name = cpu_group.group_name
|
group_name = cpu_group.group_name
|
||||||
|
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
torch_output = torch_moe2(a_full, w1_full, w2_full, topk_weights,
|
torch_output = torch_experts(a_full, w1_full, w2_full, topk_weights,
|
||||||
topk_ids)
|
topk_ids)
|
||||||
pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale,
|
pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale,
|
||||||
w2_scale, topk_weights, topk_ids,
|
w2_scale, topk_weights, topk_ids,
|
||||||
a1_scale, out_dtype, per_act_token,
|
a1_scale, out_dtype, per_act_token,
|
||||||
|
|||||||
@ -18,8 +18,8 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
has_pplx = False
|
has_pplx = False
|
||||||
|
|
||||||
|
from tests.kernels.utils import torch_experts
|
||||||
from vllm.config import VllmConfig, set_current_vllm_config
|
from vllm.config import VllmConfig, set_current_vllm_config
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
|
||||||
from vllm.model_executor.layers.fused_moe import override_config
|
from vllm.model_executor.layers.fused_moe import override_config
|
||||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||||
BatchedExperts, BatchedPrepareAndFinalize, BatchedTritonExperts)
|
BatchedExperts, BatchedPrepareAndFinalize, BatchedTritonExperts)
|
||||||
@ -163,29 +163,6 @@ def batched_moe(
|
|||||||
return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts)
|
return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts)
|
||||||
|
|
||||||
|
|
||||||
# Note: same as torch_moe but with fused_topk factored out.
|
|
||||||
def torch_moe2(
|
|
||||||
a: torch.Tensor,
|
|
||||||
w1: torch.Tensor,
|
|
||||||
w2: torch.Tensor,
|
|
||||||
topk_weight: torch.Tensor,
|
|
||||||
topk_ids: torch.Tensor,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
M, K = a.shape
|
|
||||||
topk = topk_ids.shape[1]
|
|
||||||
a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
|
|
||||||
out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
|
||||||
num_experts = w1.shape[0]
|
|
||||||
for i in range(num_experts):
|
|
||||||
mask = (topk_ids == i).view(-1)
|
|
||||||
if mask.sum():
|
|
||||||
out[mask] = SiluAndMul()(
|
|
||||||
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
|
|
||||||
|
|
||||||
return (out.view(M, -1, w2.shape[1]) *
|
|
||||||
topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("m", [1, 33, 64, 222])
|
@pytest.mark.parametrize("m", [1, 33, 64, 222])
|
||||||
@pytest.mark.parametrize("n", [128, 1024, 2048])
|
@pytest.mark.parametrize("n", [128, 1024, 2048])
|
||||||
@pytest.mark.parametrize("k", [128, 512, 1024])
|
@pytest.mark.parametrize("k", [128, 512, 1024])
|
||||||
@ -209,7 +186,7 @@ def test_fused_moe_batched_experts(
|
|||||||
|
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
|
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||||
baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
|
baseline_output = torch_experts(a, w1, w2, topk_weight, topk_ids)
|
||||||
torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids)
|
torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids)
|
||||||
batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids)
|
batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids)
|
||||||
|
|
||||||
@ -409,7 +386,7 @@ def pplx_moe(
|
|||||||
w2: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
topk_weight: torch.Tensor,
|
topk_weight: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
use_compile: bool = True,
|
use_compile: bool = False,
|
||||||
use_cudagraphs: bool = True,
|
use_cudagraphs: bool = True,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
|
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
|
||||||
@ -470,10 +447,16 @@ def pplx_moe(
|
|||||||
w1_chunk = chunk_by_rank(w1, rank, world_size).to(device)
|
w1_chunk = chunk_by_rank(w1, rank, world_size).to(device)
|
||||||
w2_chunk = chunk_by_rank(w2, rank, world_size).to(device)
|
w2_chunk = chunk_by_rank(w2, rank, world_size).to(device)
|
||||||
|
|
||||||
|
# Note: for now use_compile will error out if the problem size is
|
||||||
|
# large enough to trigger chunking. I'm leaving the flag and
|
||||||
|
# setup code in case we are able to revisit this later.
|
||||||
if use_compile:
|
if use_compile:
|
||||||
_fused_experts = torch.compile(fused_experts,
|
_fused_experts = torch.compile(fused_experts,
|
||||||
backend='inductor',
|
backend='inductor',
|
||||||
fullgraph=True)
|
fullgraph=True)
|
||||||
|
torch._dynamo.mark_dynamic(a_chunk, 0)
|
||||||
|
torch._dynamo.mark_dynamic(chunk_topk_weight, 0)
|
||||||
|
torch._dynamo.mark_dynamic(chunk_topk_ids, 0)
|
||||||
else:
|
else:
|
||||||
_fused_experts = fused_experts
|
_fused_experts = fused_experts
|
||||||
|
|
||||||
@ -576,7 +559,7 @@ def _pplx_moe(
|
|||||||
|
|
||||||
with set_current_vllm_config(vllm_config), override_config(moe_config):
|
with set_current_vllm_config(vllm_config), override_config(moe_config):
|
||||||
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
|
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||||
torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
|
torch_output = torch_experts(a, w1, w2, topk_weight, topk_ids)
|
||||||
pplx_output = pplx_moe(group_name, pgi.rank, pgi.world_size, dp_size,
|
pplx_output = pplx_moe(group_name, pgi.rank, pgi.world_size, dp_size,
|
||||||
a, w1, w2, topk_weight, topk_ids)
|
a, w1, w2, topk_weight, topk_ids)
|
||||||
# TODO (bnell): fix + re-enable
|
# TODO (bnell): fix + re-enable
|
||||||
|
|||||||
@ -403,19 +403,24 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
|
|||||||
itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, SEEDS))
|
itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, SEEDS))
|
||||||
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
|
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
|
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
|
||||||
|
monkeypatch):
|
||||||
block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
|
|
||||||
block_size = [block_m, block_m]
|
|
||||||
dtype = torch.bfloat16
|
|
||||||
|
|
||||||
if topk > E:
|
if topk > E:
|
||||||
pytest.skip(f"Skipping test: topk={topk} > E={E}")
|
pytest.skip(f"Skipping test: topk={topk} > E={E}")
|
||||||
|
|
||||||
if not _valid_deep_gemm_shape(M, N, K):
|
if not _valid_deep_gemm_shape(M, N, K):
|
||||||
pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}")
|
pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}")
|
||||||
|
|
||||||
|
chunk_size = 1024
|
||||||
|
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
|
|
||||||
|
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
|
||||||
|
|
||||||
|
block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
|
||||||
|
block_size = [block_m, block_m]
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
|
||||||
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
||||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||||
|
|
||||||
@ -451,6 +456,14 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
|
|||||||
w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
|
w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
|
||||||
w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])
|
w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])
|
||||||
|
|
||||||
|
# Note: for now use_compile will error out if the problem size is
|
||||||
|
# large enough to trigger chunking. I'm leaving the flag and
|
||||||
|
# setup code in case we are able to revisit this later.
|
||||||
|
use_compile = False
|
||||||
|
|
||||||
|
use_cudagraph = (chunk_size < M and N >= 1024 and K >= 1024
|
||||||
|
and current_platform.is_cuda_alike())
|
||||||
|
|
||||||
# Set the context to avoid lots of warning spam.
|
# Set the context to avoid lots of warning spam.
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
if M >= 128:
|
if M >= 128:
|
||||||
@ -463,7 +476,29 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
|
|||||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||||
a, score.float(), topk, False)
|
a, score.float(), topk, False)
|
||||||
|
|
||||||
out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids)
|
if use_compile:
|
||||||
|
deep_gemm_moe_fp8_fn = torch.compile(deep_gemm_moe_fp8,
|
||||||
|
backend="inductor",
|
||||||
|
fullgraph=True)
|
||||||
|
torch._dynamo.mark_dynamic(a, 0)
|
||||||
|
torch._dynamo.mark_dynamic(topk_weights, 0)
|
||||||
|
torch._dynamo.mark_dynamic(topk_ids, 0)
|
||||||
|
else:
|
||||||
|
deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8
|
||||||
|
|
||||||
|
out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
|
||||||
|
topk_ids)
|
||||||
|
|
||||||
|
if use_cudagraph:
|
||||||
|
out.fill_(0)
|
||||||
|
stream = torch.cuda.Stream()
|
||||||
|
graph = torch.cuda.CUDAGraph()
|
||||||
|
with torch.cuda.graph(graph, stream=stream):
|
||||||
|
out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
|
||||||
|
topk_ids)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
graph.replay()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
#print(f"{out.sum()=}")
|
#print(f"{out.sum()=}")
|
||||||
#print(f"{ref_out.sum()=}")
|
#print(f"{ref_out.sum()=}")
|
||||||
|
|||||||
@ -1054,12 +1054,21 @@ def compute_max_diff(output, output_ref):
|
|||||||
torch.abs(output_ref))
|
torch.abs(output_ref))
|
||||||
|
|
||||||
|
|
||||||
def torch_moe(a, w1, w2, score, topk, expert_map):
|
def torch_experts(a: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
topk_weight: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
global_num_experts: int = -1,
|
||||||
|
expert_map: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||||
|
assert (global_num_experts == -1
|
||||||
|
or (global_num_experts == w1.shape[0] and expert_map is None)
|
||||||
|
or (expert_map is not None
|
||||||
|
and global_num_experts == expert_map.shape[0]))
|
||||||
|
topk = topk_ids.shape[1]
|
||||||
B, D = a.shape
|
B, D = a.shape
|
||||||
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
||||||
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
||||||
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
|
||||||
topk_weight, topk_ids = torch.topk(score, topk)
|
|
||||||
topk_weight = topk_weight.view(-1)
|
topk_weight = topk_weight.view(-1)
|
||||||
topk_ids = topk_ids.view(-1)
|
topk_ids = topk_ids.view(-1)
|
||||||
if expert_map is not None:
|
if expert_map is not None:
|
||||||
@ -1073,6 +1082,19 @@ def torch_moe(a, w1, w2, score, topk, expert_map):
|
|||||||
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
|
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
def torch_moe(a: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
score: torch.Tensor,
|
||||||
|
topk: int,
|
||||||
|
global_num_experts: int = -1,
|
||||||
|
expert_map: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||||
|
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
||||||
|
topk_weight, topk_ids = torch.topk(score, topk)
|
||||||
|
return torch_experts(a, w1, w2, topk_weight, topk_ids, global_num_experts,
|
||||||
|
expert_map)
|
||||||
|
|
||||||
|
|
||||||
def torch_moe_single(a, w, score, topk):
|
def torch_moe_single(a, w, score, topk):
|
||||||
B, D = a.shape
|
B, D = a.shape
|
||||||
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
||||||
|
|||||||
@ -981,6 +981,7 @@ def compute_hash() -> str:
|
|||||||
"VLLM_DP_RANK",
|
"VLLM_DP_RANK",
|
||||||
"VLLM_DP_SIZE",
|
"VLLM_DP_SIZE",
|
||||||
"VLLM_USE_STANDALONE_COMPILE",
|
"VLLM_USE_STANDALONE_COMPILE",
|
||||||
|
"VLLM_FUSED_MOE_CHUNK_SIZE",
|
||||||
]
|
]
|
||||||
for key in environment_variables_to_hash:
|
for key in environment_variables_to_hash:
|
||||||
if key in environment_variables:
|
if key in environment_variables:
|
||||||
|
|||||||
@ -41,24 +41,24 @@ def run_cutlass_moe_fp8(
|
|||||||
assert w1.dtype == torch.float8_e4m3fn
|
assert w1.dtype == torch.float8_e4m3fn
|
||||||
assert w2.dtype == torch.float8_e4m3fn
|
assert w2.dtype == torch.float8_e4m3fn
|
||||||
if expert_num_tokens is None:
|
if expert_num_tokens is None:
|
||||||
assert a1q.shape[1] == w1.shape[2], "Hidden size mismatch w1"
|
assert a1q.size(1) == w1.size(2), "Hidden size mismatch w1"
|
||||||
else:
|
else:
|
||||||
assert a1q.shape[2] == w1.shape[2], "Hidden size mismatch w1"
|
assert a1q.size(2) == w1.size(2), "Hidden size mismatch w1"
|
||||||
assert w1.shape[1] == w2.shape[2] * 2, "Hidden size mismatch w2"
|
assert w1.size(1) == w2.size(2) * 2, "Hidden size mismatch w2"
|
||||||
assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[
|
assert w1_scale.dim() == 1 or w1_scale.size(
|
||||||
1] == w1.shape[1], "W1 scale shape mismatch"
|
1) == 1 or w1_scale.shape[1] == w1.size(1), "W1 scale shape mismatch"
|
||||||
assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[
|
assert w2_scale.dim() == 1 or w2_scale.size(
|
||||||
1] == w2.shape[1], "W2 scale shape mismatch"
|
1) == 1 or w2_scale.shape[1] == w2.size(1), "W2 scale shape mismatch"
|
||||||
assert w1.shape[0] == w2.shape[0], "Expert number mismatch"
|
assert w1.size(0) == w2.size(0), "Expert number mismatch"
|
||||||
assert a1q_scale is None or a1q_scale.dim(
|
assert a1q_scale is None or a1q_scale.dim() == 0 or a1q_scale.size(
|
||||||
) == 0 or a1q_scale.shape[0] == 1 or a1q_scale.shape[0] == a1q.shape[
|
0) == 1 or a1q_scale.size(
|
||||||
0], "Input scale shape mismatch"
|
0) == a1q.shape[0], "Input scale shape mismatch"
|
||||||
assert w1.shape[0] == w2.shape[0], "Weights expert number mismatch"
|
assert w1.size(0) == w2.size(0), "Weights expert number mismatch"
|
||||||
assert w1.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
|
assert w1.size(0) == w1_scale.size(0), "w1 scales expert number mismatch"
|
||||||
assert w1.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
|
assert w1.size(0) == w2_scale.size(0), "w2 scales expert number mismatch"
|
||||||
assert a2_scale is None or a2_scale.dim(
|
assert a2_scale is None or a2_scale.dim() == 0 or a2_scale.size(
|
||||||
) == 0 or a2_scale.shape[0] == 1 or a2_scale.shape[0] == a1q.shape[
|
0) == 1 or a2_scale.size(
|
||||||
0], "Intermediate scale shape mismatch"
|
0) == a1q.shape[0], "Intermediate scale shape mismatch"
|
||||||
assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
|
assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
|
||||||
if expert_map is not None:
|
if expert_map is not None:
|
||||||
assert expert_num_tokens is None
|
assert expert_num_tokens is None
|
||||||
@ -75,12 +75,12 @@ def run_cutlass_moe_fp8(
|
|||||||
# their tokens are already contiguous for each expert as a result of
|
# their tokens are already contiguous for each expert as a result of
|
||||||
# the dispatch function.
|
# the dispatch function.
|
||||||
|
|
||||||
M = a1q.shape[0] # non batched expert M
|
M = a1q.size(0) # non batched expert M
|
||||||
padded_M = a1q.shape[1] # batched expert M
|
padded_M = a1q.size(1) # batched expert M
|
||||||
_, K, N = w2.shape
|
_, K, N = w2.shape
|
||||||
device = a1q.device
|
device = a1q.device
|
||||||
|
|
||||||
assert w1.shape[2] == K
|
assert w1.size(2) == K
|
||||||
assert global_num_experts != -1
|
assert global_num_experts != -1
|
||||||
assert a1q_scale is not None
|
assert a1q_scale is not None
|
||||||
|
|
||||||
@ -91,8 +91,8 @@ def run_cutlass_moe_fp8(
|
|||||||
else:
|
else:
|
||||||
local_topk_ids = topk_ids
|
local_topk_ids = topk_ids
|
||||||
|
|
||||||
topk = local_topk_ids.shape[1]
|
topk = local_topk_ids.size(1)
|
||||||
local_E = w1.shape[0]
|
local_E = w1.size(0)
|
||||||
|
|
||||||
if use_batched_format:
|
if use_batched_format:
|
||||||
assert expert_num_tokens is not None
|
assert expert_num_tokens is not None
|
||||||
@ -111,10 +111,10 @@ def run_cutlass_moe_fp8(
|
|||||||
problem_sizes2, expert_num_tokens,
|
problem_sizes2, expert_num_tokens,
|
||||||
local_E, padded_M, N, K)
|
local_E, padded_M, N, K)
|
||||||
|
|
||||||
w1_scale = w1_scale.reshape(w1_scale.shape[0], -1)
|
w1_scale = w1_scale.reshape(w1_scale.size(0), -1)
|
||||||
w2_scale = w2_scale.reshape(w2_scale.shape[0], -1)
|
w2_scale = w2_scale.reshape(w2_scale.size(0), -1)
|
||||||
a1q = a1q.reshape(-1, a1q.shape[2])
|
a1q = a1q.reshape(-1, a1q.size(2))
|
||||||
a1q_scale = a1q_scale.reshape(-1, a1q_scale.shape[2]).contiguous()
|
a1q_scale = a1q_scale.reshape(-1, a1q_scale.size(2)).contiguous()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
expert_offsets = torch.empty((global_num_experts + 1),
|
expert_offsets = torch.empty((global_num_experts + 1),
|
||||||
@ -151,19 +151,19 @@ def run_cutlass_moe_fp8(
|
|||||||
a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale
|
a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale
|
||||||
expert_offsets = expert_offsets[:-1]
|
expert_offsets = expert_offsets[:-1]
|
||||||
|
|
||||||
ab_strides1 = torch.full((w1.shape[0], ),
|
ab_strides1 = torch.full((w1.size(0), ),
|
||||||
K,
|
K,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=torch.int64)
|
dtype=torch.int64)
|
||||||
c_strides1 = torch.full((w1.shape[0], ),
|
c_strides1 = torch.full((w1.size(0), ),
|
||||||
2 * N,
|
2 * N,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=torch.int64)
|
dtype=torch.int64)
|
||||||
ab_strides2 = torch.full((w1.shape[0], ),
|
ab_strides2 = torch.full((w1.size(0), ),
|
||||||
N,
|
N,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=torch.int64)
|
dtype=torch.int64)
|
||||||
c_strides2 = torch.full((w1.shape[0], ),
|
c_strides2 = torch.full((w1.size(0), ),
|
||||||
K,
|
K,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=torch.int64)
|
dtype=torch.int64)
|
||||||
@ -237,7 +237,7 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
workspace2: tuple[int, ...] = ()
|
workspace2: tuple[int, ...] = ()
|
||||||
output: tuple[int, ...] = ()
|
output: tuple[int, ...] = ()
|
||||||
if self.use_batched_format:
|
if self.use_batched_format:
|
||||||
padded_M = aq.shape[1]
|
padded_M = aq.size(1)
|
||||||
workspace1 = (self.max_experts_per_worker, padded_M, max(N, K))
|
workspace1 = (self.max_experts_per_worker, padded_M, max(N, K))
|
||||||
workspace2 = (self.max_experts_per_worker, padded_M, (N // 2))
|
workspace2 = (self.max_experts_per_worker, padded_M, (N // 2))
|
||||||
output = (self.max_experts_per_worker, padded_M, K)
|
output = (self.max_experts_per_worker, padded_M, K)
|
||||||
@ -332,7 +332,7 @@ def cutlass_moe_fp8(
|
|||||||
"""
|
"""
|
||||||
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
|
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
|
||||||
a2_scale.numel() != 1 if a2_scale is not None else False)
|
a2_scale.numel() != 1 if a2_scale is not None else False)
|
||||||
per_out_ch = w1_scale.numel() != w1_q.shape[0]
|
per_out_ch = w1_scale.numel() != w1_q.size(0)
|
||||||
|
|
||||||
out_dtype = a.dtype
|
out_dtype = a.dtype
|
||||||
|
|
||||||
@ -425,11 +425,11 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
|
|||||||
assert (m == m_a), "input shape mismatch"
|
assert (m == m_a), "input shape mismatch"
|
||||||
assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1"
|
assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1"
|
||||||
assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
|
assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
|
||||||
assert (topk_weights.shape[0] == m and topk_ids.shape[0]
|
assert (topk_weights.size(0) == m and topk_ids.size(0)
|
||||||
== m), ("topk must be provided for each row of a")
|
== m), ("topk must be provided for each row of a")
|
||||||
|
|
||||||
out_dtype = a.dtype
|
out_dtype = a.dtype
|
||||||
num_topk = topk_ids.shape[1]
|
num_topk = topk_ids.size(1)
|
||||||
|
|
||||||
expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device)
|
expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device)
|
||||||
blockscale_offsets = torch.empty((e + 1), dtype=torch.int32, device=device)
|
blockscale_offsets = torch.empty((e + 1), dtype=torch.int32, device=device)
|
||||||
@ -463,7 +463,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
|
|||||||
out_dtype, device)
|
out_dtype, device)
|
||||||
del rep_a_fp4, rep_a_blockscale
|
del rep_a_fp4, rep_a_blockscale
|
||||||
# hidden size dimension is split to one halfpytho sized tensor.
|
# hidden size dimension is split to one halfpytho sized tensor.
|
||||||
intermediate = torch.empty((m * num_topk, w1_fp4.shape[1] // 2),
|
intermediate = torch.empty((m * num_topk, w1_fp4.size(1) // 2),
|
||||||
device=device,
|
device=device,
|
||||||
dtype=out_dtype)
|
dtype=out_dtype)
|
||||||
|
|
||||||
|
|||||||
@ -48,7 +48,7 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor,
|
|||||||
M = hidden_states.size(0)
|
M = hidden_states.size(0)
|
||||||
_, K, N = w2.size()
|
_, K, N = w2.size()
|
||||||
if not _valid_deep_gemm_shape(M, N, K):
|
if not _valid_deep_gemm_shape(M, N, K):
|
||||||
logger.debug("DeepGemm disabled: unalinged problem size.")
|
logger.debug("DeepGemm disabled: unaligned problem size.")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if (w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn):
|
if (w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn):
|
||||||
|
|||||||
@ -25,7 +25,7 @@ def dequant_fp8(expert_x_fp8: torch.Tensor,
|
|||||||
expert_x_fp32 = expert_x_fp8.to(torch.float32).view(
|
expert_x_fp32 = expert_x_fp8.to(torch.float32).view(
|
||||||
num_experts, -1, DEEPEP_QUANT_BLOCK_SIZE)
|
num_experts, -1, DEEPEP_QUANT_BLOCK_SIZE)
|
||||||
expert_x_scales = expert_x_scales.view(num_experts, -1, 1)
|
expert_x_scales = expert_x_scales.view(num_experts, -1, 1)
|
||||||
return (expert_x_fp32 * expert_x_scales).view(expert_x_fp8.shape)
|
return (expert_x_fp32 * expert_x_scales).view(expert_x_fp8.size())
|
||||||
|
|
||||||
|
|
||||||
class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||||
|
|||||||
@ -488,10 +488,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
|
|||||||
|
|
||||||
if use_fp8_w8a8 or use_int8_w8a8:
|
if use_fp8_w8a8 or use_int8_w8a8:
|
||||||
assert B_scale is not None
|
assert B_scale is not None
|
||||||
assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0])
|
assert (block_shape is None
|
||||||
== B_scale.shape[-2])
|
or triton.cdiv(B.size(-2), block_shape[0]) == B_scale.size(-2))
|
||||||
assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1])
|
assert (block_shape is None
|
||||||
== B_scale.shape[-1])
|
or triton.cdiv(B.size(-1), block_shape[1]) == B_scale.size(-1))
|
||||||
|
|
||||||
elif use_int8_w8a16 or use_int4_w4a16:
|
elif use_int8_w8a16 or use_int4_w4a16:
|
||||||
assert B_scale is not None
|
assert B_scale is not None
|
||||||
@ -500,19 +500,19 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
|
|||||||
assert A_scale is None
|
assert A_scale is None
|
||||||
assert B_scale is None
|
assert B_scale is None
|
||||||
|
|
||||||
M = A.shape[0]
|
M = A.size(0)
|
||||||
num_tokens = M * top_k
|
num_tokens = M * top_k
|
||||||
|
|
||||||
EM = sorted_token_ids.shape[0]
|
EM = sorted_token_ids.size(0)
|
||||||
if A.shape[0] < config["BLOCK_SIZE_M"]:
|
if A.size(0) < config["BLOCK_SIZE_M"]:
|
||||||
# optimize for small batch_size.
|
# optimize for small batch_size.
|
||||||
# We assume that top_ids of each token is unique, so
|
# We assume that top_ids of each token is unique, so
|
||||||
# so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
|
# so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
|
||||||
# and we can skip some invalid blocks.
|
# and we can skip some invalid blocks.
|
||||||
EM = min(sorted_token_ids.shape[0],
|
EM = min(sorted_token_ids.size(0),
|
||||||
A.shape[0] * top_k * config['BLOCK_SIZE_M'])
|
A.size(0) * top_k * config['BLOCK_SIZE_M'])
|
||||||
grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv(
|
grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv(
|
||||||
B.shape[1], META['BLOCK_SIZE_N']), )
|
B.size(1), META['BLOCK_SIZE_N']), )
|
||||||
|
|
||||||
if (use_int8_w8a16 or use_int4_w4a16) and \
|
if (use_int8_w8a16 or use_int4_w4a16) and \
|
||||||
block_shape is not None and block_shape[1] > 0:
|
block_shape is not None and block_shape[1] > 0:
|
||||||
@ -522,16 +522,16 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
|
|||||||
use_moe_wna16_cuda = should_moe_wna16_use_cuda(
|
use_moe_wna16_cuda = should_moe_wna16_use_cuda(
|
||||||
num_valid_tokens=num_tokens,
|
num_valid_tokens=num_tokens,
|
||||||
group_size=block_shape[1],
|
group_size=block_shape[1],
|
||||||
num_experts=B.shape[0],
|
num_experts=B.size(0),
|
||||||
bit=4 if use_int4_w4a16 else 8)
|
bit=4 if use_int4_w4a16 else 8)
|
||||||
config = config.copy()
|
config = config.copy()
|
||||||
config.update(
|
config.update(
|
||||||
get_moe_wna16_block_config(config=config,
|
get_moe_wna16_block_config(config=config,
|
||||||
use_moe_wna16_cuda=use_moe_wna16_cuda,
|
use_moe_wna16_cuda=use_moe_wna16_cuda,
|
||||||
num_valid_tokens=num_tokens,
|
num_valid_tokens=num_tokens,
|
||||||
size_k=A.shape[1],
|
size_k=A.size(1),
|
||||||
size_n=B.shape[1],
|
size_n=B.size(1),
|
||||||
num_experts=B.shape[1],
|
num_experts=B.size(1),
|
||||||
group_size=block_shape[1],
|
group_size=block_shape[1],
|
||||||
real_top_k=top_k,
|
real_top_k=top_k,
|
||||||
block_size_m=config["BLOCK_SIZE_M"]))
|
block_size_m=config["BLOCK_SIZE_M"]))
|
||||||
@ -556,8 +556,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
|
|||||||
sorted_token_ids,
|
sorted_token_ids,
|
||||||
expert_ids,
|
expert_ids,
|
||||||
num_tokens_post_padded,
|
num_tokens_post_padded,
|
||||||
B.shape[1],
|
B.size(1),
|
||||||
A.shape[1],
|
A.size(1),
|
||||||
EM,
|
EM,
|
||||||
num_tokens,
|
num_tokens,
|
||||||
A.stride(0),
|
A.stride(0),
|
||||||
@ -573,7 +573,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
|
|||||||
B_zp.stride(0) if B_zp is not None else 0,
|
B_zp.stride(0) if B_zp is not None else 0,
|
||||||
B_zp.stride(2) if B_zp is not None else 0,
|
B_zp.stride(2) if B_zp is not None else 0,
|
||||||
B_zp.stride(1) if B_zp is not None else 0,
|
B_zp.stride(1) if B_zp is not None else 0,
|
||||||
block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0,
|
block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0,
|
||||||
group_size=block_shape[1],
|
group_size=block_shape[1],
|
||||||
MUL_ROUTED_WEIGHT=mul_routed_weight,
|
MUL_ROUTED_WEIGHT=mul_routed_weight,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
@ -599,8 +599,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
|
|||||||
sorted_token_ids,
|
sorted_token_ids,
|
||||||
expert_ids,
|
expert_ids,
|
||||||
num_tokens_post_padded,
|
num_tokens_post_padded,
|
||||||
B.shape[1],
|
B.size(1),
|
||||||
B.shape[2],
|
B.size(2),
|
||||||
EM,
|
EM,
|
||||||
num_tokens,
|
num_tokens,
|
||||||
A.stride(0),
|
A.stride(0),
|
||||||
@ -818,7 +818,7 @@ def try_get_optimal_moe_config(
|
|||||||
M: int,
|
M: int,
|
||||||
is_marlin: bool = False,
|
is_marlin: bool = False,
|
||||||
block_shape: Optional[list[int]] = None,
|
block_shape: Optional[list[int]] = None,
|
||||||
):
|
) -> dict[str, int]:
|
||||||
from vllm.model_executor.layers.fused_moe import get_config
|
from vllm.model_executor.layers.fused_moe import get_config
|
||||||
override_config = get_config()
|
override_config = get_config()
|
||||||
if override_config:
|
if override_config:
|
||||||
@ -873,10 +873,10 @@ def fused_topk(
|
|||||||
renormalize: bool,
|
renormalize: bool,
|
||||||
indices_type: Optional[torch.dtype] = None,
|
indices_type: Optional[torch.dtype] = None,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
assert hidden_states.shape[0] == gating_output.shape[0], (
|
assert hidden_states.size(0) == gating_output.size(0), (
|
||||||
"Number of tokens mismatch")
|
"Number of tokens mismatch")
|
||||||
|
|
||||||
M, _ = hidden_states.shape
|
M, _ = hidden_states.size()
|
||||||
|
|
||||||
topk_weights = torch.empty(M,
|
topk_weights = torch.empty(M,
|
||||||
topk,
|
topk,
|
||||||
@ -915,7 +915,7 @@ def grouped_topk(
|
|||||||
e_score_correction_bias: Optional[torch.Tensor] = None
|
e_score_correction_bias: Optional[torch.Tensor] = None
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
|
||||||
assert hidden_states.shape[0] == gating_output.shape[0], (
|
assert hidden_states.size(0) == gating_output.size(0), (
|
||||||
"Number of tokens mismatch")
|
"Number of tokens mismatch")
|
||||||
|
|
||||||
if scoring_func == "softmax":
|
if scoring_func == "softmax":
|
||||||
@ -925,7 +925,7 @@ def grouped_topk(
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported scoring function: {scoring_func}")
|
raise ValueError(f"Unsupported scoring function: {scoring_func}")
|
||||||
|
|
||||||
num_token = scores.shape[0]
|
num_token = scores.size(0)
|
||||||
if e_score_correction_bias is not None:
|
if e_score_correction_bias is not None:
|
||||||
# Store original scores before applying correction bias. We use biased
|
# Store original scores before applying correction bias. We use biased
|
||||||
# scores for expert selection but original scores for routing weights
|
# scores for expert selection but original scores for routing weights
|
||||||
@ -942,7 +942,7 @@ def grouped_topk(
|
|||||||
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
|
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
|
||||||
score_mask = group_mask.unsqueeze(-1).expand(
|
score_mask = group_mask.unsqueeze(-1).expand(
|
||||||
num_token, num_expert_group,
|
num_token, num_expert_group,
|
||||||
scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e]
|
scores.size(-1) // num_expert_group).reshape(num_token, -1) # [n, e]
|
||||||
tmp_scores = scores.masked_fill(~score_mask.bool(),
|
tmp_scores = scores.masked_fill(~score_mask.bool(),
|
||||||
float("-inf")) # [n, e]
|
float("-inf")) # [n, e]
|
||||||
|
|
||||||
@ -1162,7 +1162,7 @@ def fused_experts(hidden_states: torch.Tensor,
|
|||||||
allow_deep_gemm: bool = False) -> torch.Tensor:
|
allow_deep_gemm: bool = False) -> torch.Tensor:
|
||||||
# For now, disable DeepGemm for small N (<= 512) until better
|
# For now, disable DeepGemm for small N (<= 512) until better
|
||||||
# permute/unpermute ops are available.
|
# permute/unpermute ops are available.
|
||||||
N = w1.shape[1]
|
N = w1.size(1)
|
||||||
if (allow_deep_gemm and use_fp8_w8a8 and N > 512
|
if (allow_deep_gemm and use_fp8_w8a8 and N > 512
|
||||||
and _valid_deep_gemm(hidden_states, w1, w2)):
|
and _valid_deep_gemm(hidden_states, w1, w2)):
|
||||||
assert apply_router_weight_on_input is False
|
assert apply_router_weight_on_input is False
|
||||||
@ -1233,13 +1233,13 @@ def fused_experts_impl(
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# Check constraints.
|
# Check constraints.
|
||||||
if use_int4_w4a16:
|
if use_int4_w4a16:
|
||||||
assert hidden_states.shape[1] // 2 == w1.shape[
|
assert hidden_states.size(1) // 2 == w1.size(2), (
|
||||||
2], "Hidden size mismatch"
|
"Hidden size mismatch")
|
||||||
else:
|
else:
|
||||||
assert hidden_states.shape[1] == w1.shape[2], (
|
assert hidden_states.size(1) == w1.size(2), (
|
||||||
f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}")
|
f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}")
|
||||||
|
|
||||||
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
assert topk_weights.size() == topk_ids.size(), "topk shape mismatch"
|
||||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||||
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
|
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
|
||||||
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
|
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
|
||||||
@ -1247,12 +1247,12 @@ def fused_experts_impl(
|
|||||||
torch.float32, torch.float16, torch.bfloat16
|
torch.float32, torch.float16, torch.bfloat16
|
||||||
]
|
]
|
||||||
|
|
||||||
num_tokens = hidden_states.shape[0]
|
num_tokens = hidden_states.size(0)
|
||||||
E, N, _ = w1.shape
|
E, N, _ = w1.size()
|
||||||
K = w2.shape[1]
|
K = w2.size(1)
|
||||||
if global_num_experts == -1:
|
if global_num_experts == -1:
|
||||||
global_num_experts = E
|
global_num_experts = E
|
||||||
top_k_num = topk_ids.shape[1]
|
top_k_num = topk_ids.size(1)
|
||||||
# We execute the fused_moe kernel in chunks to circumvent this issue:
|
# We execute the fused_moe kernel in chunks to circumvent this issue:
|
||||||
# https://github.com/vllm-project/vllm/issues/5938
|
# https://github.com/vllm-project/vllm/issues/5938
|
||||||
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
||||||
@ -1269,8 +1269,8 @@ def fused_experts_impl(
|
|||||||
|
|
||||||
get_config_func = functools.partial(
|
get_config_func = functools.partial(
|
||||||
try_get_optimal_moe_config,
|
try_get_optimal_moe_config,
|
||||||
w1.shape,
|
w1.size(),
|
||||||
w2.shape,
|
w2.size(),
|
||||||
top_k_num,
|
top_k_num,
|
||||||
config_dtype,
|
config_dtype,
|
||||||
block_shape=block_shape,
|
block_shape=block_shape,
|
||||||
@ -1310,7 +1310,7 @@ def fused_experts_impl(
|
|||||||
min((chunk + 1) * CHUNK_SIZE,
|
min((chunk + 1) * CHUNK_SIZE,
|
||||||
num_tokens))
|
num_tokens))
|
||||||
curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
|
curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
|
||||||
tokens_in_chunk, _ = curr_hidden_states.shape
|
tokens_in_chunk, _ = curr_hidden_states.size()
|
||||||
|
|
||||||
if tokens_in_chunk == 0:
|
if tokens_in_chunk == 0:
|
||||||
break
|
break
|
||||||
@ -1322,7 +1322,7 @@ def fused_experts_impl(
|
|||||||
# do not need to be adjusted.
|
# do not need to be adjusted.
|
||||||
intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
|
intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
|
||||||
intermediate_cache2 = intermediate_cache2[:tokens_in_chunk *
|
intermediate_cache2 = intermediate_cache2[:tokens_in_chunk *
|
||||||
topk_ids.shape[1]]
|
topk_ids.size(1)]
|
||||||
intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
|
intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
|
||||||
config = get_config_func(tokens_in_chunk)
|
config = get_config_func(tokens_in_chunk)
|
||||||
|
|
||||||
@ -1398,7 +1398,7 @@ def fused_experts_impl(
|
|||||||
per_channel_quant=per_channel_quant,
|
per_channel_quant=per_channel_quant,
|
||||||
block_shape=block_shape)
|
block_shape=block_shape)
|
||||||
|
|
||||||
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape),
|
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
|
||||||
out_hidden_states[begin_chunk_idx:end_chunk_idx])
|
out_hidden_states[begin_chunk_idx:end_chunk_idx])
|
||||||
|
|
||||||
return out_hidden_states
|
return out_hidden_states
|
||||||
@ -1611,8 +1611,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
dtype=hidden_states.dtype)
|
dtype=hidden_states.dtype)
|
||||||
|
|
||||||
config = try_get_optimal_moe_config(
|
config = try_get_optimal_moe_config(
|
||||||
w1.shape,
|
w1.size(),
|
||||||
w2.shape,
|
w2.size(),
|
||||||
top_k_num,
|
top_k_num,
|
||||||
config_dtype,
|
config_dtype,
|
||||||
num_tokens,
|
num_tokens,
|
||||||
|
|||||||
@ -861,13 +861,11 @@ class FusedMoE(torch.nn.Module):
|
|||||||
self.global_num_experts = num_experts
|
self.global_num_experts = num_experts
|
||||||
|
|
||||||
# For smuggling this layer into the fused moe custom op
|
# For smuggling this layer into the fused moe custom op
|
||||||
self.use_direct_call = self.dp_size == 1
|
compilation_config = vllm_config.compilation_config
|
||||||
if not self.use_direct_call:
|
if prefix in compilation_config.static_forward_context:
|
||||||
compilation_config = vllm_config.compilation_config
|
raise ValueError("Duplicate layer name: {}".format(prefix))
|
||||||
if prefix in compilation_config.static_forward_context:
|
compilation_config.static_forward_context[prefix] = self
|
||||||
raise ValueError("Duplicate layer name: {}".format(prefix))
|
self.layer_name = prefix
|
||||||
compilation_config.static_forward_context[prefix] = self
|
|
||||||
self.layer_name = prefix
|
|
||||||
|
|
||||||
# Determine expert maps
|
# Determine expert maps
|
||||||
if self.use_ep:
|
if self.use_ep:
|
||||||
@ -1361,11 +1359,8 @@ class FusedMoE(torch.nn.Module):
|
|||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor,
|
def forward(self, hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor):
|
router_logits: torch.Tensor):
|
||||||
if self.use_direct_call:
|
return torch.ops.vllm.moe_forward(hidden_states, router_logits,
|
||||||
return self.forward_impl(hidden_states, router_logits)
|
self.layer_name)
|
||||||
else:
|
|
||||||
return torch.ops.vllm.moe_forward(hidden_states, router_logits,
|
|
||||||
self.layer_name)
|
|
||||||
|
|
||||||
def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
|
def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
|
||||||
full_router_logits: torch.Tensor):
|
full_router_logits: torch.Tensor):
|
||||||
|
|||||||
@ -69,7 +69,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
a1 = a1 * rank_topk_weights.to(a1.dtype)
|
a1 = a1 * rank_topk_weights.to(a1.dtype)
|
||||||
|
|
||||||
repeat_cols = 4
|
repeat_cols = 4
|
||||||
repeat_rows = 1 if self.per_act_token else a1.shape[0]
|
repeat_rows = 1 if self.per_act_token else a1.size(0)
|
||||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||||
a1, (None if self.per_act_token else a1_scale), self.quant_dtype,
|
a1, (None if self.per_act_token else a1_scale), self.quant_dtype,
|
||||||
self.per_act_token, self.block_shape)
|
self.per_act_token, self.block_shape)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user