Expert Parallelism (EP) Support for DeepSeek V2 (#12583)
parent 7940d8a6a7
commit 781096e385
@@ -468,7 +468,8 @@ def main(args: argparse.Namespace):
         topk = config.num_experts_per_tok
         intermediate_size = config.intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
-    elif config.architectures[0] == "DeepseekV3ForCausalLM":
+    elif (config.architectures[0] == "DeepseekV3ForCausalLM"
+          or config.architectures[0] == "DeepseekV2ForCausalLM"):
         E = config.n_routed_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
tests/distributed/test_expert_parallel.py (new file, 227 lines)
@@ -0,0 +1,227 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from typing import List, Literal, NamedTuple, Optional

import pytest

from vllm.config import TaskOption
from vllm.logger import init_logger

from ..utils import compare_two_settings, fork_new_process_for_each_test

logger = init_logger("test_expert_parallel")


class ParallelSetup(NamedTuple):
    tp_size: int
    eager_mode: bool
    chunked_prefill: bool


class EPTestOptions(NamedTuple):
    trust_remote_code: bool
    tokenizer_mode: Optional[str]
    load_format: Optional[str] = None
    hf_overrides: Optional[str] = None


@dataclass
class EPTestSettings:
    parallel_setups: List[ParallelSetup]
    distributed_backends: List[str]
    task: TaskOption
    test_options: EPTestOptions

    @staticmethod
    def detailed(
        *,
        tp_base: int = 2,
        task: TaskOption = "auto",
        trust_remote_code: bool = False,
        tokenizer_mode: Optional[str] = None,
        load_format: Optional[str] = None,
        hf_overrides: Optional[str] = None,
    ):
        return EPTestSettings(
            parallel_setups=[
                ParallelSetup(tp_size=tp_base,
                              eager_mode=False,
                              chunked_prefill=False),
                ParallelSetup(tp_size=tp_base,
                              eager_mode=False,
                              chunked_prefill=True),
                ParallelSetup(tp_size=tp_base,
                              eager_mode=True,
                              chunked_prefill=False),
                ParallelSetup(tp_size=2 * tp_base,
                              eager_mode=False,
                              chunked_prefill=True),
                ParallelSetup(tp_size=2 * tp_base,
                              eager_mode=True,
                              chunked_prefill=False),
            ],
            distributed_backends=["mp", "ray"],
            task=task,
            test_options=EPTestOptions(trust_remote_code=trust_remote_code,
                                       tokenizer_mode=tokenizer_mode,
                                       load_format=load_format,
                                       hf_overrides=hf_overrides),
        )

    @staticmethod
    def fast(
        *,
        tp_base: int = 2,
        task: TaskOption = "auto",
        trust_remote_code: bool = False,
        tokenizer_mode: Optional[str] = None,
        load_format: Optional[str] = None,
        hf_overrides: Optional[str] = None,
    ):
        return EPTestSettings(
            parallel_setups=[
                ParallelSetup(tp_size=tp_base,
                              eager_mode=True,
                              chunked_prefill=False),
            ],
            distributed_backends=["mp"],
            task=task,
            test_options=EPTestOptions(trust_remote_code=trust_remote_code,
                                       tokenizer_mode=tokenizer_mode,
                                       load_format=load_format,
                                       hf_overrides=hf_overrides),
        )

    def iter_params(self, model_name: str):
        opts = self.test_options

        for parallel_setup in self.parallel_setups:
            for distributed_backend in self.distributed_backends:
                yield (model_name, parallel_setup, distributed_backend,
                       self.task, opts)


# NOTE: You can adjust tp_base locally to fit the model in GPU
# The values displayed here are only a rough indicator of the size of the model

# yapf: disable
TEST_MODELS = {
    "deepseek-ai/DeepSeek-V2-Lite-Chat": EPTestSettings.fast(
        trust_remote_code=True),
    "mistralai/Mixtral-8x7B-Instruct-v0.1": EPTestSettings.fast(tp_base=4),
}


def _compare_tp(
    model_name: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    task: TaskOption,
    test_options: EPTestOptions,
    num_gpus_available: int,
    *,
    method: Literal["generate"],
):
    (
        tp_size,
        eager_mode,
        chunked_prefill,
    ) = parallel_setup
    (
        trust_remote_code,
        tokenizer_mode,
        load_format,
        hf_overrides,
    ) = test_options

    if num_gpus_available < tp_size:
        pytest.skip(f"Need at least {tp_size} GPUs")

    common_args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "float16",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "8",
        "--load-format",
        "auto",
    ]
    if chunked_prefill:
        common_args.append("--enable-chunked-prefill")
    if eager_mode:
        common_args.append("--enforce-eager")
    if task != "auto":
        common_args.extend(["--task", task])
    if trust_remote_code:
        common_args.append("--trust-remote-code")
    if tokenizer_mode:
        common_args.extend(["--tokenizer-mode", tokenizer_mode])
    if load_format:
        common_args.extend(["--load-format", load_format])
    if hf_overrides:
        common_args.extend(["--hf-overrides", hf_overrides])

    ep_env = {
        "VLLM_TEST_ENABLE_EP": "1",
    }

    ep_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        distributed_backend,
    ]

    # compare without expert parallelism
    tp_env = {
        "VLLM_TEST_ENABLE_EP": "0",
    }

    tp_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        "mp",
    ]

    try:
        compare_two_settings(model_name,
                             ep_args,
                             tp_args,
                             ep_env,
                             tp_env,
                             method=method,
                             max_wait_seconds=360)
    except Exception:
        raise


@pytest.mark.parametrize(
    ("model_name", "parallel_setup", "distributed_backend", "task",
     "test_options"),
    [
        params for model_name, settings in TEST_MODELS.items()
        for params in settings.iter_params(model_name)
    ],
)
@fork_new_process_for_each_test
def test_ep(
    model_name: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    task: TaskOption,
    test_options: EPTestOptions,
    num_gpus_available,
):
    _compare_tp(model_name,
                parallel_setup,
                distributed_backend,
                task,
                test_options,
                num_gpus_available,
                method="generate")
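A rough way to run only the DeepSeek-V2 entry of the new test locally (illustrative, not part of the commit; it assumes a host with at least two GPUs and the vLLM test requirements installed):

import pytest

# -k matches the parametrized model id "deepseek-ai/DeepSeek-V2-Lite-Chat".
pytest.main(["-x", "tests/distributed/test_expert_parallel.py", "-k", "DeepSeek"])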
@@ -99,13 +99,8 @@ def test_fused_marlin_moe_awq(
         num_bits=num_bits,
     )

-    torch_output = torch_moe(
-        a,
-        w_ref1.transpose(1, 2),
-        w_ref2.transpose(1, 2),
-        score,
-        topk,
-    )
+    torch_output = torch_moe(a, w_ref1.transpose(1, 2), w_ref2.transpose(1, 2),
+                             score, topk, None)

     assert compute_max_diff(marlin_output, torch_output) < 4e-2

@@ -26,6 +26,7 @@ from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types

 NUM_EXPERTS = [8, 64]
+EP_SIZE = [1, 4]
 TOP_KS = [2, 6]


@@ -34,6 +35,7 @@ TOP_KS = [2, 6]
 @pytest.mark.parametrize("k", [128, 511, 1024])
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("ep_size", EP_SIZE)
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 def test_fused_moe(
     m: int,
@@ -41,6 +43,7 @@ def test_fused_moe(
     k: int,
     e: int,
     topk: int,
+    ep_size: int,
     dtype: torch.dtype,
 ):
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
@@ -48,10 +51,38 @@ def test_fused_moe(
     w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

     score = torch.randn((m, e), device="cuda", dtype=dtype)
-    triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
-    torch_output = torch_moe(a, w1, w2, score, topk)
+
+    if ep_size > 1:
+        local_e = e // ep_size
+        e_ids = torch.randint(0,
+                              e, (local_e, ),
+                              device="cuda",
+                              dtype=torch.int32)
+        e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
+        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
+        w1 = w1[e_ids]
+        w2 = w2[e_ids]
+    else:
+        e_map = None
+
+    triton_output = fused_moe(a,
+                              w1,
+                              w2,
+                              score,
+                              topk,
+                              global_num_experts=e,
+                              expert_map=e_map,
+                              renormalize=False)
+    torch_output = torch_moe(a, w1, w2, score, topk, e_map)
     torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
-    iterative_output = iterative_moe(a, w1, w2, score, topk, renormalize=False)
+    iterative_output = iterative_moe(a,
+                                     w1,
+                                     w2,
+                                     score,
+                                     topk,
+                                     global_num_experts=e,
+                                     expert_map=e_map,
+                                     renormalize=False)
     torch.testing.assert_close(iterative_output,
                                torch_output,
                                atol=2e-2,
@@ -63,13 +94,14 @@ def test_fused_moe(
 @pytest.mark.parametrize("k", [128, 1024])
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("ep_size", EP_SIZE)
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("group_size", [64, 128])
 @pytest.mark.parametrize("has_zp", [True, False])
 @pytest.mark.parametrize("weight_bits", [4, 8])
 def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
-                        dtype: torch.dtype, group_size: int, has_zp: bool,
-                        weight_bits: int):
+                        ep_size: int, dtype: torch.dtype, group_size: int,
+                        has_zp: bool, weight_bits: int):
     print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits)
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
@@ -130,6 +162,25 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
         if has_zp:
             w_qzeros[expert_id] = qzeros

+    if ep_size > 1:
+        local_e = e // ep_size
+        e_ids = torch.randint(0,
+                              e, (local_e, ),
+                              device="cuda",
+                              dtype=torch.int32)
+        e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
+        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
+        w1_ref = w1_ref[e_ids]
+        w2_ref = w2_ref[e_ids]
+        w1_qweight = w1_qweight[e_ids]
+        w2_qweight = w2_qweight[e_ids]
+        w1_scales = w1_scales[e_ids]
+        w2_scales = w2_scales[e_ids]
+        w1_qzeros = w1_qzeros[e_ids]
+        w2_qzeros = w2_qzeros[e_ids]
+    else:
+        e_map = None
+
     triton_output = fused_moe(a,
                               w1_qweight,
                               w2_qweight,
@@ -138,12 +189,14 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
                               renormalize=False,
                               use_int4_w4a16=weight_bits == 4,
                               use_int8_w8a16=weight_bits == 8,
+                              global_num_experts=e,
+                              expert_map=e_map,
                               w1_scale=w1_scales,
                               w2_scale=w2_scales,
                               w1_zp=w1_qzeros if has_zp else None,
                               w2_zp=w2_qzeros if has_zp else None,
                               block_shape=[0, group_size])
-    torch_output = torch_moe(a, w1_ref, w2_ref, score, topk)
+    torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map)
     torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)

@@ -1053,7 +1053,7 @@ def compute_max_diff(output, output_ref):
         torch.abs(output_ref))


-def torch_moe(a, w1, w2, score, topk):
+def torch_moe(a, w1, w2, score, topk, expert_map):
     B, D = a.shape
     a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
     out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
@@ -1061,6 +1061,8 @@ def torch_moe(a, w1, w2, score, topk):
     topk_weight, topk_ids = torch.topk(score, topk)
     topk_weight = topk_weight.view(-1)
     topk_ids = topk_ids.view(-1)
+    if expert_map is not None:
+        topk_ids = expert_map[topk_ids]
     for i in range(w1.shape[0]):
         mask = topk_ids == i
         if mask.sum():
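As an aside, the expert-map convention used throughout this change can be shown with a small standalone PyTorch sketch (names are illustrative, not taken from the diff): entries hold the local index of experts owned by the current rank and -1 everywhere else, so top-k ids remapped to -1 never match any local expert in the per-expert loop of torch_moe above.

import torch

num_global_experts = 8
local_expert_ids = torch.tensor([2, 3, 4, 5])  # experts owned by this rank

expert_map = torch.full((num_global_experts, ), -1, dtype=torch.int32)
expert_map[local_expert_ids] = torch.arange(len(local_expert_ids),
                                            dtype=torch.int32)

topk_ids = torch.tensor([0, 3, 5, 7])  # global expert ids picked by top-k
local_ids = expert_map[topk_ids]       # tensor([-1, 1, 3, -1], dtype=torch.int32)
# Ids equal to -1 belong to experts hosted on other ranks, so this rank's
# per-expert loop simply skips those tokens.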
@@ -297,12 +297,12 @@ def _test_completion_close(
                                             logprobs=5,
                                             temperature=0.0)

-        logporbs = completion.choices[0].logprobs.top_logprobs[0]
-        logporbs = {k: round(v, 2) for k, v in logporbs.items()}
+        logprobs = completion.choices[0].logprobs.top_logprobs[0]
+        logprobs = {k: round(v, 2) for k, v in logprobs.items()}

         results.append({
             "test": "completion_close",
-            "logprobs": logporbs,
+            "logprobs": logprobs,
         })

     return results
@@ -677,6 +677,23 @@ class ModelConfig:
                 "fallback to the eager mode.")
             self.enforce_eager = True

+    def _verify_with_expert_parallelism(self) -> None:
+        num_expert_names = [
+            "moe_num_experts",  # Dbrx
+            "num_experts",  # Jamba
+            "n_routed_experts",  # DeepSeek
+            "num_local_experts",  # Mixtral
+        ]
+        num_experts = 0
+        for name in num_expert_names:
+            num_experts = getattr(self.hf_text_config, name, 0)
+            if num_experts > 0:
+                break
+        if num_experts < 1:
+            raise ValueError(
+                "Number of experts in the model must be greater than 0 "
+                "when expert parallelism is enabled.")
+
     def verify_async_output_proc(self, parallel_config, speculative_config,
                                  device_config) -> None:
         if not self.use_async_output_proc:
@@ -730,6 +747,9 @@ class ModelConfig:
                 " must be divisible by tensor parallel size "
                 f"({tensor_parallel_size}).")

+        if envs.VLLM_TEST_ENABLE_EP:
+            self._verify_with_expert_parallelism()
+
         pipeline_parallel_size = parallel_config.pipeline_parallel_size
         if pipeline_parallel_size > 1:
             architectures = getattr(self.hf_config, "architectures", [])
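What the probing loop in _verify_with_expert_parallelism does can be seen with a toy stand-in for hf_text_config (illustrative only; real configs come from the Hugging Face model config, and DeepSeek-V2 exposes the routed expert count as n_routed_experts, as the benchmark hunk above also relies on):

class _ToyTextConfig:
    n_routed_experts = 64  # DeepSeek-style attribute name

num_expert_names = [
    "moe_num_experts",  # Dbrx
    "num_experts",  # Jamba
    "n_routed_experts",  # DeepSeek
    "num_local_experts",  # Mixtral
]

num_experts = 0
for name in num_expert_names:
    num_experts = getattr(_ToyTextConfig, name, 0)
    if num_experts > 0:
        break
print(num_experts)  # 64, so the EP check passes for an MoE model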
@@ -86,6 +86,7 @@ if TYPE_CHECKING:
     VLLM_MLA_PERFORM_MATRIX_ABSORPTION: bool = True
     VLLM_MLA_DISABLE_REQUANTIZATION: bool = False
     VLLM_MLA_CUDA_MEM_ALIGN_KV_CACHE: bool = True
+    VLLM_TEST_ENABLE_EP: bool = False
     VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False
     VLLM_RAY_PER_WORKER_GPUS: float = 1.0
     VLLM_RAY_BUNDLE_INDICES: str = ""
@@ -570,6 +571,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     lambda: bool(int(os.getenv("VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
                  ),

+    # If set, vLLM will use the experimental expert parallel implementation on
+    # the FusedMoE layer, using tensor parallelism size as expert parallelism
+    # size.
+    "VLLM_TEST_ENABLE_EP":
+    lambda: bool(int(os.getenv("VLLM_TEST_ENABLE_EP", "0"))),
+
     # Number of GPUs per worker in Ray, if it is set to be a fraction,
     # it allows ray to schedule multiple actors on a single GPU,
     # so that users can colocate other actors on the same GPUs as vLLM.
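The new flag is opt-in and experimental. A minimal sketch of exercising it from Python, under the assumption of a host with at least two GPUs (model name and sizes are placeholders taken from the new test above):

import os

# Set before constructing the engine so vLLM's env lookup sees it; per the
# description above, the tensor-parallel size is then reused as the EP size.
os.environ["VLLM_TEST_ENABLE_EP"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite-Chat",
          tensor_parallel_size=2,
          trust_remote_code=True,
          enforce_eager=True)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)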
@@ -20,6 +20,18 @@ from vllm.utils import direct_register_custom_op
 logger = init_logger(__name__)


+@triton.jit
+def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token,
+                          token_mask, BLOCK_SIZE_M, BLOCK_SIZE_N,
+                          compute_type):
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
+        None, :]
+    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, accumulator, mask=c_mask)
+
+
 @triton.jit
 def fused_moe_kernel_gptq_awq(
         # Pointers to matrices
@@ -120,17 +132,26 @@ def fused_moe_kernel_gptq_awq(
     offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
     token_mask = offs_token < num_valid_tokens

+    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
+    if off_experts == -1:
+        # -----------------------------------------------------------
+        # Write back zeros to the output when the expert is not
+        # in the current expert parallel rank.
+        write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N,
+                              offs_token, token_mask, BLOCK_SIZE_M,
+                              BLOCK_SIZE_N, compute_type)
+        return
+
     offs_bn = (pid_n * BLOCK_SIZE_N +
                tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
     offs_k = tl.arange(0, BLOCK_SIZE_K)
     a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
                       offs_k[None, :] * stride_ak)

-    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
-
     if use_int4_w4a16:
         b_ptrs = b_ptr + off_experts * stride_be + \
-            (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn
+            (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * \
+                stride_bn
         b_shifter = (offs_k[:, None] % 2) * 4
     elif use_int8_w8a16:
         b_ptrs = b_ptr + off_experts * stride_be + \
@@ -170,7 +191,8 @@ def fused_moe_kernel_gptq_awq(

             b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + \
                 offs_bn[None, :] * stride_bsn + \
-                ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
+                ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * \
+                    stride_bsk
             b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
             b_scale = b_scale.to(tl.float32)

@@ -319,13 +341,22 @@ def fused_moe_kernel(
     offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
     token_mask = offs_token < num_valid_tokens

+    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
+    if off_experts == -1:
+        # -----------------------------------------------------------
+        # Write back zeros to the output when the expert is not
+        # in the current expert parallel rank.
+        write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N,
+                              offs_token, token_mask, BLOCK_SIZE_M,
+                              BLOCK_SIZE_N, compute_type)
+        return
+
     offs_bn = (pid_n * BLOCK_SIZE_N +
                tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
     offs_k = tl.arange(0, BLOCK_SIZE_K)
     a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
                       offs_k[None, :] * stride_ak)

-    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
     b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
                                                 offs_bn[None, :] * stride_bn)
     if use_int8_w8a16:
@@ -349,7 +380,6 @@ def fused_moe_kernel(
     # of fp32 values for higher accuracy.
     # `accumulator` will be converted back to fp16 after the loop.
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         # Load the next block of A and B, generate a mask by checking the
         # K dimension.
@@ -544,8 +574,11 @@ def moe_align_block_size_triton(


 def moe_align_block_size(
-        topk_ids: torch.Tensor, block_size: int,
-        num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        topk_ids: torch.Tensor,
+        block_size: int,
+        num_experts: int,
+        expert_map: torch.Tensor = None
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Aligns the token distribution across experts to be compatible with block
     size for matrix multiplication.
@@ -555,6 +588,10 @@ def moe_align_block_size(
         top-k expert indices for each token.
     - block_size: The block size used in block matrix multiplication.
     - num_experts: The total number of experts.
+    - expert_map: A tensor of shape [num_experts] that maps the expert index
+        from the global space to the local index space of the current
+        expert parallel shard. If the expert is not in the current expert
+        parallel shard, the mapping is set to -1.

     Returns:
     - sorted_token_ids: A tensor containing the sorted token indices according
@@ -589,7 +626,9 @@ def moe_align_block_size(
                              device=topk_ids.device)
     sorted_ids.fill_(topk_ids.numel())
     max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
-    expert_ids = torch.empty((max_num_m_blocks, ),
+    # Expert ids must be zeroed out to prevent index out of bounds error while
+    # mapping global expert ids to local expert ids in expert parallelism.
+    expert_ids = torch.zeros((max_num_m_blocks, ),
                              dtype=torch.int32,
                              device=topk_ids.device)
     num_tokens_post_pad = torch.empty((1),
@@ -618,6 +657,9 @@ def moe_align_block_size(
     else:
         ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
                                  expert_ids, num_tokens_post_pad)
+    if expert_map is not None:
+        expert_ids = expert_map[expert_ids]
+
     return sorted_ids, expert_ids, num_tokens_post_pad


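The reason expert_ids is now zero-initialized can be seen in isolation: the padded tail of the block table is never written by the alignment kernel, yet every entry is later used to index expert_map, so stale values outside [0, num_experts) would make that gather fail. A small sketch with illustrative shapes (not taken from the kernel):

import torch

num_experts = 8
expert_map = torch.full((num_experts, ), -1, dtype=torch.int32)
expert_map[:4] = torch.arange(4, dtype=torch.int32)  # this rank owns experts 0-3

max_num_m_blocks = 6
# Zero-init: blocks the kernel never fills stay at expert 0, a valid index.
expert_ids = torch.zeros((max_num_m_blocks, ), dtype=torch.int32)
expert_ids[:4] = torch.tensor([1, 3, 5, 7], dtype=torch.int32)  # written blocks

local_expert_ids = expert_map[expert_ids]
# tensor([ 1,  3, -1, -1,  0,  0]): every lookup stays in bounds, and the
# blocks routed to experts 5 and 7 become -1, which the kernel above turns
# into zero outputs via write_zeros_to_output.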
@@ -1001,6 +1043,8 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
                           use_fp8_w8a8: bool = False,
                           use_int8_w8a16: bool = False,
                           use_int4_w4a16: bool = False,
+                          global_num_experts: int = -1,
+                          expert_map: Optional[torch.Tensor] = None,
                           w1_scale: Optional[torch.Tensor] = None,
                           w2_scale: Optional[torch.Tensor] = None,
                           w1_zp: Optional[torch.Tensor] = None,
@@ -1009,8 +1053,9 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
                           a2_scale: Optional[torch.Tensor] = None,
                           block_shape: Optional[List[int]] = None) -> None:
     fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
-                       use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, w1_scale,
-                       w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape)
+                       use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16,
+                       global_num_experts, expert_map, w1_scale, w2_scale,
+                       w1_zp, w2_zp, a1_scale, a2_scale, block_shape)


 def inplace_fused_experts_fake(
@@ -1022,6 +1067,8 @@ def inplace_fused_experts_fake(
         use_fp8_w8a8: bool = False,
         use_int8_w8a16: bool = False,
         use_int4_w4a16: bool = False,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
         w1_scale: Optional[torch.Tensor] = None,
         w2_scale: Optional[torch.Tensor] = None,
         w1_zp: Optional[torch.Tensor] = None,
@@ -1049,6 +1096,8 @@ def outplace_fused_experts(
         use_fp8_w8a8: bool = False,
         use_int8_w8a16: bool = False,
         use_int4_w4a16: bool = False,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
         w1_scale: Optional[torch.Tensor] = None,
         w2_scale: Optional[torch.Tensor] = None,
         w1_zp: Optional[torch.Tensor] = None,
@@ -1058,8 +1107,9 @@ def outplace_fused_experts(
         block_shape: Optional[List[int]] = None) -> torch.Tensor:
     return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
                               False, use_fp8_w8a8, use_int8_w8a16,
-                              use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp,
-                              a1_scale, a2_scale, block_shape)
+                              use_int4_w4a16, global_num_experts, expert_map,
+                              w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
+                              a2_scale, block_shape)


 def outplace_fused_experts_fake(
@@ -1071,6 +1121,8 @@ def outplace_fused_experts_fake(
         use_fp8_w8a8: bool = False,
         use_int8_w8a16: bool = False,
         use_int4_w4a16: bool = False,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
         w1_scale: Optional[torch.Tensor] = None,
         w2_scale: Optional[torch.Tensor] = None,
         w1_zp: Optional[torch.Tensor] = None,
@@ -1098,26 +1150,27 @@ def fused_experts(hidden_states: torch.Tensor,
                   use_fp8_w8a8: bool = False,
                   use_int8_w8a16: bool = False,
                   use_int4_w4a16: bool = False,
+                  global_num_experts: int = -1,
+                  expert_map: Optional[torch.Tensor] = None,
                   w1_scale: Optional[torch.Tensor] = None,
                   w2_scale: Optional[torch.Tensor] = None,
                   w1_zp: Optional[torch.Tensor] = None,
                   w2_zp: Optional[torch.Tensor] = None,
                   a1_scale: Optional[torch.Tensor] = None,
                   a2_scale: Optional[torch.Tensor] = None,
-                  block_shape: Optional[List[int]] = None):
+                  block_shape: Optional[List[int]] = None) -> torch.Tensor:
+
     if inplace:
-        torch.ops.vllm.inplace_fused_experts(hidden_states, w1, w2,
-                                             topk_weights, topk_ids,
-                                             use_fp8_w8a8, use_int8_w8a16,
-                                             use_int4_w4a16, w1_scale,
-                                             w2_scale, w1_zp, w2_zp, a1_scale,
-                                             a2_scale, block_shape)
+        torch.ops.vllm.inplace_fused_experts(
+            hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8,
+            use_int8_w8a16, use_int4_w4a16, global_num_experts, expert_map,
+            w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape)
         return hidden_states
     else:
        return torch.ops.vllm.outplace_fused_experts(
            hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8,
-           use_int8_w8a16, use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp,
-           a1_scale, a2_scale, block_shape)
+           use_int8_w8a16, use_int4_w4a16, global_num_experts, expert_map,
+           w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape)


 def fused_experts_impl(hidden_states: torch.Tensor,
@@ -1129,6 +1182,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
                        use_fp8_w8a8: bool = False,
                        use_int8_w8a16: bool = False,
                        use_int4_w4a16: bool = False,
+                       global_num_experts: int = -1,
+                       expert_map: Optional[torch.Tensor] = None,
                        w1_scale: Optional[torch.Tensor] = None,
                        w2_scale: Optional[torch.Tensor] = None,
                        w1_zp: Optional[torch.Tensor] = None,
@@ -1153,6 +1208,9 @@ def fused_experts_impl(hidden_states: torch.Tensor,

     num_tokens, _ = hidden_states.shape
     E, N, _ = w1.shape
+    if global_num_experts == -1:
+        global_num_experts = E
+    top_k_num = topk_ids.shape[1]
     # We execute the fused_moe kernel in chunks to circumvent this issue:
     # https://github.com/vllm-project/vllm/issues/5938
     CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
@@ -1166,20 +1224,20 @@ def fused_experts_impl(hidden_states: torch.Tensor,
         try_get_optimal_moe_config,
         w1.shape,
         w2.shape,
-        topk_ids.shape[1],
+        top_k_num,
         config_dtype,
         block_shape=block_shape,
     )

     config = get_config_func(M)

-    intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
+    intermediate_cache1 = torch.empty((M, top_k_num, N),
                                       device=hidden_states.device,
                                       dtype=hidden_states.dtype)
-    intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
+    intermediate_cache2 = torch.empty((M * top_k_num, N // 2),
                                       device=hidden_states.device,
                                       dtype=hidden_states.dtype)
-    intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
+    intermediate_cache3 = torch.empty((M, top_k_num, w2.shape[1]),
                                       device=hidden_states.device,
                                       dtype=hidden_states.dtype)

@@ -1221,7 +1279,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
         curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]

         sorted_token_ids, expert_ids, num_tokens_post_padded = (
-            moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E))
+            moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'],
+                                 global_num_experts, expert_map))

         invoke_fused_moe_kernel(curr_hidden_states,
                                 w1,
@@ -1235,7 +1294,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
                                 expert_ids,
                                 num_tokens_post_padded,
                                 False,
-                                topk_ids.shape[1],
+                                top_k_num,
                                 config,
                                 compute_type=compute_type,
                                 use_fp8_w8a8=use_fp8_w8a8,
@@ -1286,6 +1345,8 @@ def fused_moe(
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1320,6 +1381,11 @@ def fused_moe(
     - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
         activation to compute the inner products for w1 and w2.
         Defaults to False.
+    - global_num_experts (int): The total number of experts in the global
+        expert space.
+    - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
+        from the global expert space to the local expert space of the expert
+        parallel shard.
     - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
         w1.
     - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
@@ -1334,8 +1400,6 @@ def fused_moe(
     Returns:
     - torch.Tensor: The output tensor after applying the MoE layer.
     """
-    # Check constraints.
-    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"

     if use_grouped_topk:
         assert num_expert_group is not None and topk_group is not None
@@ -1358,6 +1422,8 @@ def fused_moe(
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
         use_int4_w4a16=use_int4_w4a16,
+        global_num_experts=global_num_experts,
+        expert_map=expert_map,
         w1_scale=w1_scale,
         w2_scale=w2_scale,
         w1_zp=w1_zp,
@@ -6,6 +6,7 @@ from typing import Callable, List, Optional, Tuple

 import torch

+import vllm.envs as envs
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
@@ -55,6 +56,8 @@ class FusedMoEMethodBase(QuantizeMethodBase):
         use_grouped_topk: bool = False,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None
@@ -113,6 +116,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         use_grouped_topk: bool = False,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None
@@ -125,6 +130,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             use_grouped_topk=use_grouped_topk,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
+            global_num_experts=global_num_experts,
+            expert_map=expert_map,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias)
@@ -139,6 +146,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         renormalize: bool,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None
@@ -160,7 +169,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                              w2=layer.w2_weight,
                              topk_weights=topk_weights,
                              topk_ids=topk_ids,
-                             inplace=True)
+                             inplace=True,
+                             global_num_experts=global_num_experts,
+                             expert_map=expert_map)

     def forward_cpu(
         self,
@@ -172,6 +183,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         renormalize: bool,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         **kwargs,
     ):
@@ -196,6 +209,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         renormalize: bool,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None
@@ -215,6 +230,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             w2=layer.w2_weight,
             topk=top_k,
             gating_output=router_logits,
+            global_num_experts=global_num_experts,
+            expert_map=expert_map,
             renormalize=renormalize)

     forward_native = forward_cuda
@@ -255,6 +272,7 @@ class FusedMoE(torch.nn.Module):
         topk_group: Optional[int] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
+        ep_size: Optional[int] = None,
         prefix: str = "",
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
@@ -267,8 +285,13 @@ class FusedMoE(torch.nn.Module):

         self.tp_size = (tp_size if tp_size is not None else
                         get_tensor_model_parallel_world_size())
+        if envs.VLLM_TEST_ENABLE_EP:
+            self.ep_size = self.tp_size
+            self.tp_size = 1
+        else:
+            self.ep_size = 1
         self.top_k = top_k
-        self.num_experts = num_experts
+        self.num_experts = num_experts  # Global number of experts
         assert intermediate_size % self.tp_size == 0
         self.intermediate_size_per_partition = intermediate_size // self.tp_size
         self.reduce_results = reduce_results
@@ -281,6 +304,26 @@ class FusedMoE(torch.nn.Module):
         self.custom_routing_function = custom_routing_function
         self.scoring_func = scoring_func
         self.e_score_correction_bias = e_score_correction_bias
+        self.expert_map = None
+
+        if self.ep_size > 1:
+            # Create a tensor of size num_experts filled with -1
+            self.expert_map = torch.full((self.num_experts, ),
+                                         -1,
+                                         dtype=torch.int32)
+            # Create a expert map for the local experts
+            local_num_experts = num_experts // self.ep_size
+            ep_rank = get_tensor_model_parallel_rank()
+            if ep_rank < (self.ep_size - 1):
+                # Each non-last rank gets local_num_experts experts.
+                self.expert_map[ep_rank * local_num_experts:
+                                (ep_rank + 1) * local_num_experts] = \
+                    torch.arange(0, local_num_experts, dtype=torch.int32)
+            else:
+                # All remaining experts are assigned to the last rank.
+                local_num_experts = num_experts - ep_rank * local_num_experts
+                self.expert_map[-local_num_experts:] = \
+                    torch.arange(0, local_num_experts, dtype=torch.int32)
+
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
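The partitioning above is easier to see outside the module: each rank gets a contiguous slice of the global experts, with any remainder absorbed by the last rank. A standalone sketch under those assumptions (not tied to vLLM's distributed state; the helper name is made up):

import torch


def build_expert_map(global_num_experts: int, ep_size: int,
                     ep_rank: int) -> torch.Tensor:
    # Same contiguous partitioning as FusedMoE.__init__ in the hunk above.
    expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32)
    local_num_experts = global_num_experts // ep_size
    if ep_rank < ep_size - 1:
        start = ep_rank * local_num_experts
        expert_map[start:start + local_num_experts] = torch.arange(
            local_num_experts, dtype=torch.int32)
    else:
        # The last rank also absorbs any remainder when the split is uneven.
        local_num_experts = global_num_experts - ep_rank * local_num_experts
        expert_map[-local_num_experts:] = torch.arange(local_num_experts,
                                                       dtype=torch.int32)
    return expert_map


# 10 experts over 4 ranks: ranks 0-2 own 2 experts each, rank 3 owns 4.
for rank in range(4):
    print(rank, build_expert_map(10, 4, rank).tolist())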
@@ -293,8 +336,11 @@ class FusedMoE(torch.nn.Module):
             self.quant_method = quant_config.get_quant_method(self, prefix)
         assert self.quant_method is not None

+        local_num_experts = torch.sum(self.expert_map != -1) \
+            if self.expert_map is not None else num_experts
+
         moe_quant_params = {
-            "num_experts": num_experts,
+            "num_experts": local_num_experts,
             "hidden_size": hidden_size,
             "intermediate_size_per_partition":
             self.intermediate_size_per_partition,
@@ -423,10 +469,22 @@ class FusedMoE(torch.nn.Module):
             assert shard_id in ("w1", "w3")
             expert_data.copy_(loaded_weight)

+    def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
+        if self.expert_map is None:
+            return expert_id
+        return self.expert_map[expert_id].item()
+
     def weight_loader(self, param: torch.nn.Parameter,
                       loaded_weight: torch.Tensor, weight_name: str,
                       shard_id: str, expert_id: int) -> None:

+        expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
+        if expert_id == -1:
+            return
+
+        # TP rank is set to 0 if EP is enabled
+        tp_rank = 0 if self.ep_size > 1 else get_tensor_model_parallel_rank()
+
         # compressed-tensors checkpoints with packed weights are stored flipped
         # TODO (mgoin): check self.quant_method.quant_config.quant_format
         # against known CompressionFormat enum values that have this quality
@@ -447,7 +505,6 @@ class FusedMoE(torch.nn.Module):
         SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}

         expert_data = param.data[expert_id]
-        tp_rank = get_tensor_model_parallel_rank()

         # is_transposed: if the dim to shard the weight
         # should be flipped. Required by GPTQ, compressed-tensors
@@ -590,13 +647,16 @@ class FusedMoE(torch.nn.Module):
             top_k=self.top_k,
             renormalize=self.renormalize,
             use_grouped_topk=self.use_grouped_topk,
+            global_num_experts=self.num_experts,
+            expert_map=self.expert_map,
             topk_group=self.topk_group,
             num_expert_group=self.num_expert_group,
             custom_routing_function=self.custom_routing_function,
             scoring_func=self.scoring_func,
             e_score_correction_bias=self.e_score_correction_bias)

-        if self.reduce_results and self.tp_size > 1:
+        if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
+            # Default set to False. (May have to add shared expert outputs.)
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)

@@ -10,7 +10,9 @@ def fused_moe(
     w2: torch.Tensor,
     gating_output: torch.Tensor,
     topk: int,
-    renormalize: bool,
+    global_num_experts: int,
+    expert_map: torch.Tensor = None,
+    renormalize: bool = False,
 ) -> torch.Tensor:
     """
     Args:
@@ -18,6 +20,7 @@ def fused_moe(
         w1: [num_experts, intermediate_size * 2, hidden_size]
         w2: [num_experts, hidden_size, intermediate_size]
         gating_output: [*, num_experts]
+        expert_map: [num_experts]
     """
     orig_shape = hidden_states.shape
     hidden_size = hidden_states.shape[-1]
@@ -27,13 +30,16 @@ def fused_moe(
     dtype = hidden_states.dtype

     hidden_states = hidden_states.view(num_tokens, hidden_size)
-    gating_output = gating_output.view(num_tokens, num_experts)
+    gating_output = gating_output.view(num_tokens, global_num_experts)
     topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
     topk_weights, selected_experts = topk_weights.topk(topk, dim=-1)
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
     topk_weights = topk_weights.to(dtype)

+    if expert_map is not None:
+        selected_experts = expert_map[selected_experts]
+
     final_hidden_states = None
     for expert_idx in range(num_experts):
         expert_w1 = w1[expert_idx]
|||||||
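In this reference MoE path, routing still happens over all global_num_experts logits; expert_map then translates the selected global expert ids into local ids, and ids that come back as -1 belong to other ranks, so they never match a local expert in the per-expert loop and their contribution arrives via the later cross-rank reduction. A toy example of that remapping step, reusing the hypothetical map from the sketch above:

import torch

# rank 1 of 2 with 8 global experts (same map as the sketch above)
expert_map = torch.tensor([-1, -1, -1, -1, 0, 1, 2, 3])
selected_experts = torch.tensor([[1, 5], [4, 7]])  # global top-2 ids per token
print(expert_map[selected_experts])
# tensor([[-1,  1],
#         [ 0,  3]])  -> the -1 entries simply never match on this rank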
@ -464,10 +464,17 @@ class AWQMoEMethod(FusedMoEMethodBase):
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
+       global_num_experts: int = -1,
+       expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
+       if expert_map is not None:
+           raise NotImplementedError(
+               "Expert Parallelism is not supported for "
+               "fused Marlin MoE method.")
+
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
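The same guard is added to CompressedTensorsWNA16MoEMethod further down: both apply() paths dispatch to the fused Marlin MoE kernel, which presumably cannot yet restrict itself to a rank-local subset of experts, so passing an expert_map fails fast with NotImplementedError rather than silently computing against the wrong experts.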
@ -214,6 +214,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
+       global_num_experts: int = -1,
+       expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
@ -239,6 +241,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
            topk_ids=topk_ids,
            inplace=True,
            use_fp8_w8a8=True,
+           global_num_experts=global_num_experts,
+           expert_map=expert_map,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=layer.w13_input_scale,
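The remaining quantized MoE methods in this commit (ExpertsInt8MoEMethod, Fp8MoEMethod, MoeWNA16Method, QuarkW8A8Fp8MoEMethod below) follow the same pattern as this hunk: accept global_num_experts and expert_map in apply() and pass them straight through to the fused kernel. A condensed sketch of that shared shape, with simplified, illustrative names rather than the literal vLLM signatures:

from typing import Callable, Optional

import torch

def apply_quantized_moe(fused_experts: Callable,   # underlying fused MoE kernel
                        x: torch.Tensor,
                        layer: torch.nn.Module,    # holds quantized expert weights/scales
                        topk_weights: torch.Tensor,
                        topk_ids: torch.Tensor,
                        global_num_experts: int = -1,
                        expert_map: Optional[torch.Tensor] = None,
                        **quant_kwargs) -> torch.Tensor:
    # Common pattern in this commit: thread the EP arguments through to the
    # fused kernel, which uses expert_map to keep only the tokens routed to
    # this rank's local experts.
    return fused_experts(x,
                         layer.w13_weight,
                         layer.w2_weight,
                         topk_weights=topk_weights,
                         topk_ids=topk_ids,
                         inplace=True,
                         global_num_experts=global_num_experts,
                         expert_map=expert_map,
                         **quant_kwargs)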
@ -540,10 +544,16 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
+       global_num_experts: int = -1,
+       expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
+       if expert_map is not None:
+           raise NotImplementedError(
+               "Expert Parallelism is not supported for "
+               "fused Marlin MoE method.")

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
@ -108,6 +108,8 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
+       global_num_experts: int = -1,
+       expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
@ -133,6 +135,8 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
            topk_ids=topk_ids,
            inplace=True,
            use_int8_w8a16=True,
+           global_num_experts=global_num_experts,
+           expert_map=expert_map,
            w1_scale=layer.w13_scale,
            w2_scale=layer.w2_scale)
@ -670,6 +670,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
+       global_num_experts: int = -1,
+       expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
@ -697,6 +699,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
            topk_ids=topk_ids,
            inplace=True,
            use_fp8_w8a8=True,
+           global_num_experts=global_num_experts,
+           expert_map=expert_map,
            w1_scale=(layer.w13_weight_scale_inv
                      if self.block_quant else layer.w13_weight_scale),
            w2_scale=(layer.w2_weight_scale_inv
@ -585,6 +585,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
+       global_num_experts: int = -1,
+       expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
@ -288,6 +288,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
+       global_num_experts: int = -1,
+       expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
@ -317,6 +319,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
            inplace=True,
            use_int4_w4a16=weight_bits == 4,
            use_int8_w8a16=weight_bits == 8,
+           global_num_experts=global_num_experts,
+           expert_map=expert_map,
            w1_scale=layer.w13_scales,
            w2_scale=layer.w2_scales,
            w1_zp=layer.w13_qzeros if has_zp else None,
@ -198,6 +198,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
+       global_num_experts: int = -1,
+       expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
@ -223,6 +225,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
            topk_ids=topk_ids,
            inplace=True,
            use_fp8_w8a8=True,
+           global_num_experts=global_num_experts,
+           expert_map=expert_map,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=layer.w13_input_scale,
@ -106,10 +106,6 @@ class DeepseekV2MoE(nn.Module):
        self.routed_scaling_factor = config.routed_scaling_factor
        self.n_shared_experts = config.n_shared_experts
        self.routed_scaling_factor = config.routed_scaling_factor
-       if self.tp_size > config.n_routed_experts:
-           raise ValueError(
-               f"Tensor parallel size {self.tp_size} is greater than "
-               f"the number of experts {config.n_routed_experts}.")

        if config.hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {config.hidden_act}. "
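The DeepseekV2MoE constructor drops the old requirement that the tensor-parallel size stay below the number of routed experts, presumably because expert parallelism changes how experts map onto ranks (each EP rank can own whole experts rather than a shard of every expert), making this model-level guard obsolete in its old form. A small illustration of how routed experts could be partitioned across EP ranks, using a hypothetical helper name:

def local_expert_range(n_routed_experts: int, ep_size: int, ep_rank: int):
    # Hypothetical helper: each EP rank keeps a contiguous block of whole
    # experts instead of a slice of every expert.
    assert n_routed_experts % ep_size == 0, "uneven splits not handled here"
    per_rank = n_routed_experts // ep_size
    start = ep_rank * per_rank
    return start, start + per_rank

# DeepSeek-V2 configures 160 routed experts; on 8 EP ranks, rank 3 owns 60..79
print(local_expert_range(160, 8, 3))  # (60, 80)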