[torch.compile] support moe models (#9632)

Signed-off-by: youkaichao <youkaichao@gmail.com>

Commit: 32176fee73
Parent: 4e2d95e372
@@ -88,6 +88,8 @@ def benchmark_config(
         input_gating.copy_(gating_output[i])
 
     def run():
+        from vllm.model_executor.layers.fused_moe import override_config
+        with override_config(config):
             fused_moe(
                 x,
                 w1,
@@ -96,7 +98,6 @@ def benchmark_config(
                 topk,
                 renormalize=True,
                 inplace=True,
-                override_config=config,
                 use_fp8_w8a8=use_fp8_w8a8,
                 use_int8_w8a16=use_int8_w8a16,
                 w1_scale=w1_scale,
@@ -13,11 +13,11 @@ from ..utils import compare_all_settings
 @pytest.mark.parametrize(
     "model, model_args, pp_size, tp_size, attn_backend, method, fullgraph",
     [
-        ("meta-llama/Llama-3.2-1B", [], 2, 2, "FLASH_ATTN", "generate", True),
+        ("meta-llama/Llama-3.2-1B", [], 2, 2, "FLASHINFER", "generate", True),
         ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples",
          ["--quantization", "compressed-tensors"
          ], 1, 1, "FLASH_ATTN", "generate", True),
-        ("google/gemma-2-2b-it", [], 1, 2, "FLASHINFER", "generate", True),
+        ("ibm/PowerMoE-3b", [], 1, 2, "FLASH_ATTN", "generate", True),
         # TODO: add multi-modality test for llava
         ("llava-hf/llava-1.5-7b-hf", [], 2, 1, "FLASHINFER", "generate", False)
     ])
@@ -5,11 +5,10 @@ Run `pytest tests/kernels/test_awq_marlin.py`.
 import pytest
 import torch
 
+import vllm.model_executor.layers.fused_moe  # noqa
 from tests.kernels.utils import (compute_max_diff, stack_and_dev, torch_moe,
                                  torch_moe_single)
 from vllm import _custom_ops as ops
-from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
-    fused_marlin_moe, single_marlin_moe)
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
     awq_marlin_quantize)
@@ -81,7 +80,7 @@ def test_fused_marlin_moe_awq(
     score = torch.randn((m, e), device="cuda", dtype=dtype)
 
     topk_weights, topk_ids = fused_topk(a, score, topk, False)
-    marlin_output = fused_marlin_moe(
+    marlin_output = torch.ops.vllm.fused_marlin_moe(
         a,
         qweight1,
         qweight2,
@@ -150,7 +149,7 @@ def test_single_marlin_moe_multiply_awq(
 
     score = torch.randn((m, e), device="cuda", dtype=dtype)
 
-    marlin_output = single_marlin_moe(a,
+    marlin_output = torch.ops.vllm.single_marlin_moe(a,
                                       qweight,
                                       scales,
                                       score,
@@ -7,12 +7,11 @@ import torch
 from transformers import MixtralConfig
 from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
 
+import vllm.model_executor.layers.fused_moe  # noqa
 from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev,
                                  torch_moe, torch_moe_single)
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.fused_moe import fused_moe
-from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
-    fused_marlin_moe, single_marlin_moe)
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     fused_topk, moe_align_block_size)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
@@ -193,7 +192,7 @@ def test_fused_marlin_moe(
         topk,
         renormalize=False,
     )
-    marlin_output = fused_marlin_moe(
+    marlin_output = torch.ops.vllm.fused_marlin_moe(
         a,
         qweight1,
         qweight2,
@@ -309,7 +308,7 @@ def test_single_marlin_moe_multiply(
     sort_indices = stack_and_dev(sort_indices_l)
 
     score = torch.randn((m, e), device="cuda", dtype=dtype)
-    marlin_output = single_marlin_moe(
+    marlin_output = torch.ops.vllm.single_marlin_moe(
         a,
         qweight,
         scales,
@@ -1,23 +1,43 @@
+from contextlib import contextmanager
+from typing import Any, Dict, Optional
+
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
 from vllm.triton_utils import HAS_TRITON
 
+_config: Optional[Dict[str, Any]] = None
+
+
+@contextmanager
+def override_config(config):
+    global _config
+    old_config = _config
+    _config = config
+    yield
+    _config = old_config
+
+
+def get_config() -> Optional[Dict[str, Any]]:
+    return _config
+
+
 __all__ = [
     "FusedMoE",
     "FusedMoEMethodBase",
     "FusedMoeWeightScaleSupported",
+    "override_config",
+    "get_config",
 ]
 
 if HAS_TRITON:
-    from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
-        fused_marlin_moe, single_marlin_moe)
+    # import to register the custom ops
+    import vllm.model_executor.layers.fused_moe.fused_marlin_moe  # noqa
+    import vllm.model_executor.layers.fused_moe.fused_moe  # noqa
     from vllm.model_executor.layers.fused_moe.fused_moe import (
         fused_experts, fused_moe, fused_topk, get_config_file_name,
         grouped_topk)
 
     __all__ += [
-        "fused_marlin_moe",
-        "single_marlin_moe",
        "fused_moe",
        "fused_topk",
        "fused_experts",
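Taken together with the benchmark hunks at the top, this changes the calling convention: instead of threading an override_config= keyword through fused_moe and the Marlin entry points, callers now wrap the call in the override_config context manager exported from vllm.model_executor.layers.fused_moe, and the kernels read it back via get_config(). A minimal sketch of the new convention on a CUDA/Triton install; the tensors, shapes, and config keys below are illustrative and not taken from this diff:

import torch

from vllm.model_executor.layers.fused_moe import fused_moe, override_config

config = {"BLOCK_SIZE_M": 64}  # hypothetical Triton tuning config
x = torch.randn(4, 128, device="cuda", dtype=torch.float16)
w1 = torch.randn(8, 256, 128, device="cuda", dtype=torch.float16)  # [experts, 2*intermediate, hidden]
w2 = torch.randn(8, 128, 128, device="cuda", dtype=torch.float16)  # [experts, hidden, intermediate]
gating = torch.randn(4, 8, device="cuda", dtype=torch.float16)

with override_config(config):  # replaces the removed override_config= kwarg
    out = fused_moe(x, w1, w2, gating, topk=2, renormalize=True)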
@@ -1,6 +1,6 @@
 """Fused MoE utilities for GPTQ."""
 import functools
-from typing import Any, Dict, Optional
+from typing import Optional
 
 import torch
 
@@ -18,6 +18,7 @@ def get_scalar_type(num_bits: int, has_zp: bool):
     return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
 
 
+@torch.library.custom_op("vllm::single_marlin_moe", mutates_args=[])
 def single_marlin_moe(
         hidden_states: torch.Tensor,
         w: torch.Tensor,
@@ -28,7 +29,6 @@ def single_marlin_moe(
         g_idx: Optional[torch.Tensor] = None,
         sort_indices: Optional[torch.Tensor] = None,
         w_zeros: Optional[torch.Tensor] = None,
-        override_config: Optional[Dict[str, Any]] = None,
         num_bits: int = 8,
         is_k_full: bool = True,
 ) -> torch.Tensor:
@@ -49,8 +49,6 @@ def single_marlin_moe(
     - topk (int): The number of top-k experts to select.
     - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
     - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
-    - override_config (Optional[Dict[str, Any]]): Optional override
-        for the kernel configuration.
     - num_bits (bool): The number of bits in expert weights quantization.
 
     Returns:
@@ -79,7 +77,6 @@ def single_marlin_moe(
         w.shape,
         topk_ids.shape[1],
         None,
-        override_config=override_config,
         is_marlin=True)
     config = get_config_func(M)
 
@@ -122,6 +119,24 @@ def single_marlin_moe(
     return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
 
 
+@single_marlin_moe.register_fake
+def _(
+        hidden_states: torch.Tensor,
+        w: torch.Tensor,
+        scales: torch.Tensor,
+        gating_output: torch.Tensor,
+        topk: int,
+        renormalize: bool,
+        g_idx: Optional[torch.Tensor] = None,
+        sort_indices: Optional[torch.Tensor] = None,
+        w_zeros: Optional[torch.Tensor] = None,
+        num_bits: int = 8,
+        is_k_full: bool = True,
+) -> torch.Tensor:
+    return torch.empty_like(hidden_states)
+
+
+@torch.library.custom_op("vllm::fused_marlin_moe", mutates_args=[])
 def fused_marlin_moe(
         hidden_states: torch.Tensor,
         w1: torch.Tensor,
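The register_fake definition above (and the matching one for fused_marlin_moe later in this file) gives the registered custom op a "fake" implementation that only produces an empty tensor of the right shape and dtype. That is what torch.compile's fake-tensor tracing calls instead of the real kernel, so the whole Marlin MoE computation appears in the compiled graph as a single opaque torch.ops.vllm node; the call sites in the tests and quantization methods then go through torch.ops.vllm.<op> rather than a direct Python call. A standalone sketch of the same pattern under a made-up namespace (requires a PyTorch version with torch.library.custom_op, i.e. 2.4+; none of these names are vLLM APIs):

import torch


@torch.library.custom_op("mylib::scale_rows", mutates_args=[])
def scale_rows(x: torch.Tensor, factor: float) -> torch.Tensor:
    # stand-in for a real fused kernel
    return x * factor


@scale_rows.register_fake
def _(x: torch.Tensor, factor: float) -> torch.Tensor:
    # shape/dtype-only result used during fake-tensor tracing
    return torch.empty_like(x)


@torch.compile(fullgraph=True)
def f(x: torch.Tensor) -> torch.Tensor:
    # resolved through torch.ops, just like torch.ops.vllm.fused_marlin_moe above
    return torch.ops.mylib.scale_rows(x, 2.0)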
@@ -137,7 +152,6 @@ def fused_marlin_moe(
         sort_indices2: Optional[torch.Tensor] = None,
         w1_zeros: Optional[torch.Tensor] = None,
         w2_zeros: Optional[torch.Tensor] = None,
-        override_config: Optional[Dict[str, Any]] = None,
         num_bits: int = 8,
         is_k_full: bool = True,
 ) -> torch.Tensor:
@@ -161,8 +175,6 @@ def fused_marlin_moe(
         permutation.
     - topk_weights (torch.Tensor): Top-k weights.
     - topk_ids (torch.Tensor): Indices of topk-k elements.
-    - override_config (Optional[Dict[str, Any]]): Optional override
-        for the kernel configuration.
     - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
     - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
     - num_bits (bool): The number of bits in expert weights quantization.
@@ -209,7 +221,6 @@ def fused_marlin_moe(
         w2.shape,
         topk_ids.shape[1],
         None,
-        override_config=override_config,
         is_marlin=True,
     )
     config = get_config_func(M)
@@ -311,3 +322,25 @@ def fused_marlin_moe(
 
     return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
                      dim=1)
+
+
+@fused_marlin_moe.register_fake
+def _(
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        w1_scale: torch.Tensor,
+        w2_scale: torch.Tensor,
+        gating_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        g_idx1: Optional[torch.Tensor] = None,
+        g_idx2: Optional[torch.Tensor] = None,
+        sort_indices1: Optional[torch.Tensor] = None,
+        sort_indices2: Optional[torch.Tensor] = None,
+        w1_zeros: Optional[torch.Tensor] = None,
+        w2_zeros: Optional[torch.Tensor] = None,
+        num_bits: int = 8,
+        is_k_full: bool = True,
+) -> torch.Tensor:
+    return torch.empty_like(hidden_states)
@@ -358,9 +358,10 @@ def try_get_optimal_moe_config(
     top_k: int,
     dtype: Optional[str],
     M: int,
-    override_config: Optional[Dict[str, Any]] = None,
     is_marlin: bool = False,
 ):
+    from vllm.model_executor.layers.fused_moe import get_config
+    override_config = get_config()
     if override_config:
         config = override_config
     else:
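try_get_optimal_moe_config now pulls the override from module state via get_config() instead of receiving it as a parameter, which is why the override_config argument disappears from every signature in this file. The mechanics are easy to check in isolation (pure Python, no GPU needed); the config dict below is a placeholder:

from vllm.model_executor.layers.fused_moe import get_config, override_config

assert get_config() is None                      # no override outside the context
with override_config({"BLOCK_SIZE_M": 64}):      # hypothetical config dict
    assert get_config() == {"BLOCK_SIZE_M": 64}  # visible to try_get_optimal_moe_config
assert get_config() is None                      # restored on exit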
@@ -465,13 +466,103 @@ def get_config_dtype_str(dtype: torch.dtype,
     return None
 
 
+@torch.library.custom_op("vllm::inplace_fused_experts",
+                         mutates_args=["hidden_states"])
+def inplace_fused_experts(hidden_states: torch.Tensor,
+                          w1: torch.Tensor,
+                          w2: torch.Tensor,
+                          topk_weights: torch.Tensor,
+                          topk_ids: torch.Tensor,
+                          use_fp8_w8a8: bool = False,
+                          use_int8_w8a16: bool = False,
+                          w1_scale: Optional[torch.Tensor] = None,
+                          w2_scale: Optional[torch.Tensor] = None,
+                          a1_scale: Optional[torch.Tensor] = None,
+                          a2_scale: Optional[torch.Tensor] = None) -> None:
+    fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
+                       use_fp8_w8a8, use_int8_w8a16, w1_scale, w2_scale,
+                       a1_scale, a2_scale)
+
+
+@inplace_fused_experts.register_fake
+def _(hidden_states: torch.Tensor,
+      w1: torch.Tensor,
+      w2: torch.Tensor,
+      topk_weights: torch.Tensor,
+      topk_ids: torch.Tensor,
+      use_fp8_w8a8: bool = False,
+      use_int8_w8a16: bool = False,
+      w1_scale: Optional[torch.Tensor] = None,
+      w2_scale: Optional[torch.Tensor] = None,
+      a1_scale: Optional[torch.Tensor] = None,
+      a2_scale: Optional[torch.Tensor] = None) -> None:
+    pass
+
+
+@torch.library.custom_op("vllm::outplace_fused_experts", mutates_args=[])
+def outplace_fused_experts(
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        use_fp8_w8a8: bool = False,
+        use_int8_w8a16: bool = False,
+        w1_scale: Optional[torch.Tensor] = None,
+        w2_scale: Optional[torch.Tensor] = None,
+        a1_scale: Optional[torch.Tensor] = None,
+        a2_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
+    return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
+                              False, use_fp8_w8a8, use_int8_w8a16, w1_scale,
+                              w2_scale, a1_scale, a2_scale)
+
+
+@outplace_fused_experts.register_fake
+def _(hidden_states: torch.Tensor,
+      w1: torch.Tensor,
+      w2: torch.Tensor,
+      topk_weights: torch.Tensor,
+      topk_ids: torch.Tensor,
+      use_fp8_w8a8: bool = False,
+      use_int8_w8a16: bool = False,
+      w1_scale: Optional[torch.Tensor] = None,
+      w2_scale: Optional[torch.Tensor] = None,
+      a1_scale: Optional[torch.Tensor] = None,
+      a2_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
+    return torch.empty_like(hidden_states)
+
+
 def fused_experts(hidden_states: torch.Tensor,
                   w1: torch.Tensor,
                   w2: torch.Tensor,
                   topk_weights: torch.Tensor,
                   topk_ids: torch.Tensor,
                   inplace: bool = False,
-                  override_config: Optional[Dict[str, Any]] = None,
+                  use_fp8_w8a8: bool = False,
+                  use_int8_w8a16: bool = False,
+                  w1_scale: Optional[torch.Tensor] = None,
+                  w2_scale: Optional[torch.Tensor] = None,
+                  a1_scale: Optional[torch.Tensor] = None,
+                  a2_scale: Optional[torch.Tensor] = None):
+    if inplace:
+        torch.ops.vllm.inplace_fused_experts(hidden_states, w1, w2,
+                                             topk_weights, topk_ids,
+                                             use_fp8_w8a8, use_int8_w8a16,
+                                             w1_scale, w2_scale, a1_scale,
+                                             a2_scale)
+        return hidden_states
+    else:
+        return torch.ops.vllm.outplace_fused_experts(
+            hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8,
+            use_int8_w8a16, w1_scale, w2_scale, a1_scale, a2_scale)
+
+
+def fused_experts_impl(hidden_states: torch.Tensor,
+                       w1: torch.Tensor,
+                       w2: torch.Tensor,
+                       topk_weights: torch.Tensor,
+                       topk_ids: torch.Tensor,
+                       inplace: bool = False,
                        use_fp8_w8a8: bool = False,
                        use_int8_w8a16: bool = False,
                        w1_scale: Optional[torch.Tensor] = None,
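fused_experts is split this way because a torch.library custom op has to declare its mutations up front: the in-place variant lists hidden_states in mutates_args and returns nothing, the out-of-place variant is pure and returns a new tensor, and the original body survives as fused_experts_impl behind a thin Python dispatcher, so torch.compile sees an accurate read/write contract either way. A self-contained sketch of that shape, using illustrative names rather than the vLLM ops:

import torch


@torch.library.custom_op("mylib::inplace_scale", mutates_args=["x"])
def inplace_scale(x: torch.Tensor, factor: float) -> None:
    x.mul_(factor)  # the declared mutation


@inplace_scale.register_fake
def _(x: torch.Tensor, factor: float) -> None:
    # no outputs to describe; the op only mutates its input
    pass


def scale(x: torch.Tensor, factor: float, inplace: bool = False) -> torch.Tensor:
    # thin dispatcher, mirroring fused_experts above
    if inplace:
        torch.ops.mylib.inplace_scale(x, factor)
        return x
    return x * factor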
@@ -504,7 +595,6 @@ def fused_experts(hidden_states: torch.Tensor,
         w2.shape,
         topk_ids.shape[1],
         config_dtype,
-        override_config=override_config,
     )
 
     config = get_config_func(M)
@@ -602,7 +692,6 @@ def fused_moe(
     topk: int,
     renormalize: bool,
     inplace: bool = False,
-    override_config: Optional[Dict[str, Any]] = None,
     use_grouped_topk: bool = False,
     num_expert_group: Optional[int] = None,
     topk_group: Optional[int] = None,
@@ -628,8 +717,6 @@ def fused_moe(
     - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
     - inplace (bool): If True, perform the operation in-place.
         Defaults to False.
-    - override_config (Optional[Dict[str, Any]]): Optional override
-        for the kernel configuration.
     - num_expert_group: Optional[int]: additional parameter for grouped_topk
     - topk_group: Optional[int]: additional parameter for grouped_topk
     - use_grouped_topk: If True, use grouped_topk instead of fused_topk
@@ -667,7 +754,6 @@ def fused_moe(
         topk_weights,
         topk_ids,
         inplace=inplace,
-        override_config=override_config,
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
         w1_scale=w1_scale,
@@ -12,7 +12,16 @@ from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
 
+if current_platform.is_cuda_alike():
+    from .fused_moe import fused_experts
+else:
+    fused_experts = None  # type: ignore
+if current_platform.is_tpu():
+    from .moe_pallas import fused_moe as fused_moe_pallas
+else:
+    fused_moe_pallas = None  # type: ignore
 logger = init_logger(__name__)
 
 
@@ -96,9 +105,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None
     ) -> torch.Tensor:
-        from vllm.model_executor.layers.fused_moe.fused_moe import (
-            fused_experts)
-
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -132,18 +138,19 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None
     ) -> torch.Tensor:
-        from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
         assert not use_grouped_topk
         assert num_expert_group is None
         assert topk_group is None
         assert custom_routing_function is None
-        return fused_moe(hidden_states=x,
+        return fused_moe_pallas(hidden_states=x,
                                 w1=layer.w13_weight,
                                 w2=layer.w2_weight,
                                 topk=top_k,
                                 gating_output=router_logits,
                                 renormalize=renormalize)
 
+    forward_native = forward_cuda
+
 
 class FusedMoE(torch.nn.Module):
     """FusedMoE layer for MoE models.
@@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, List, Optional
 import torch
 from torch.nn import Parameter
 
+import vllm.model_executor.layers.fused_moe  # noqa
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.layer import (
@@ -435,10 +436,6 @@ class AWQMoEMethod(FusedMoEMethodBase):
         topk_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
     ) -> torch.Tensor:
-
-        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
-            fused_marlin_moe)
-
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -449,7 +446,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function)
 
-        return fused_marlin_moe(
+        return torch.ops.vllm.fused_marlin_moe(
             x,
             layer.w13_qweight,
             layer.w2_qweight,
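The bare "import vllm.model_executor.layers.fused_moe  # noqa" added at the top of the AWQMoEMethod module (and of the CompressedTensorsWNA16MoEMethod and GPTQMarlinMoEMethod modules below) is what makes torch.ops.vllm.fused_marlin_moe resolvable at these call sites: importing the package runs the "if HAS_TRITON" block shown earlier, which imports the fused Marlin and Triton MoE modules purely for their torch.library.custom_op registration side effects. A quick way to see the effect, assuming a CUDA/Triton install of vLLM:

import torch

import vllm.model_executor.layers.fused_moe  # noqa: F401  (registers the custom ops)

print(torch.ops.vllm.fused_marlin_moe)  # resolves only after the import above has run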
@@ -6,6 +6,7 @@ import torch
 from compressed_tensors import CompressionFormat
 from compressed_tensors.quantization import QuantizationStrategy
 
+import vllm.model_executor.layers.fused_moe  # noqa
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                   FusedMoeWeightScaleSupported)
@@ -481,10 +482,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         topk_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
     ) -> torch.Tensor:
-
-        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
-            fused_marlin_moe)
-
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -495,7 +492,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function)
 
-        return fused_marlin_moe(
+        return torch.ops.vllm.fused_marlin_moe(
             x,
             layer.w13_weight_packed,
             layer.w2_weight_packed,
@@ -2,6 +2,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Union
 
 import torch
 
+import vllm.model_executor.layers.fused_moe  # noqa
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.layer import (
@@ -536,9 +537,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         topk_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
     ) -> torch.Tensor:
-        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
-            fused_marlin_moe)
-
         # The input must currently be float16
         orig_dtype = x.dtype
         x = x.half()
@@ -553,7 +551,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             num_expert_group=num_expert_group,
             custom_routing_function=None)
 
-        return fused_marlin_moe(
+        return torch.ops.vllm.fused_marlin_moe(
             x,
             layer.w13_qweight,
             layer.w2_qweight,
@@ -28,6 +28,7 @@ from torch import nn
 from transformers.models.granitemoe import GraniteMoeConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -244,6 +245,7 @@ class GraniteMoeDecoderLayer(nn.Module):
         return hidden_states
 
 
+@support_torch_compile
 class GraniteMoeModel(nn.Module):
 
     def __init__(
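Applying support_torch_compile to GraniteMoeModel is the model-side half of the change: with the MoE kernels exposed as registered custom ops, the PowerMoE backbone can be captured by torch.compile without tracing into their Triton internals, which is what the ibm/PowerMoE-3b fullgraph test entry added near the top exercises. Shape of the change, shown as a fragment mirroring the hunk above rather than a runnable module:

from torch import nn

from vllm.compilation.decorators import support_torch_compile


@support_torch_compile
class GraniteMoeModel(nn.Module):
    # class body (layers, forward, ...) unchanged by this commit
    ...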