[torch.compile] support moe models (#9632)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Authored by youkaichao on 2024-10-27 21:58:04 -07:00; committed by GitHub.
parent 4e2d95e372
commit 32176fee73
12 changed files with 216 additions and 77 deletions


@ -88,22 +88,23 @@ def benchmark_config(
input_gating.copy_(gating_output[i])
def run():
fused_moe(
x,
w1,
w2,
input_gating,
topk,
renormalize=True,
inplace=True,
override_config=config,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
from vllm.model_executor.layers.fused_moe import override_config
with override_config(config):
fused_moe(
x,
w1,
w2,
input_gating,
topk,
renormalize=True,
inplace=True,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
# JIT compilation & warmup
run()


@ -13,11 +13,11 @@ from ..utils import compare_all_settings
@pytest.mark.parametrize(
"model, model_args, pp_size, tp_size, attn_backend, method, fullgraph",
[
("meta-llama/Llama-3.2-1B", [], 2, 2, "FLASH_ATTN", "generate", True),
("meta-llama/Llama-3.2-1B", [], 2, 2, "FLASHINFER", "generate", True),
("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples",
["--quantization", "compressed-tensors"
], 1, 1, "FLASH_ATTN", "generate", True),
("google/gemma-2-2b-it", [], 1, 2, "FLASHINFER", "generate", True),
("ibm/PowerMoE-3b", [], 1, 2, "FLASH_ATTN", "generate", True),
# TODO: add multi-modality test for llava
("llava-hf/llava-1.5-7b-hf", [], 2, 1, "FLASHINFER", "generate", False)
])


@ -5,11 +5,10 @@ Run `pytest tests/kernels/test_awq_marlin.py`.
import pytest
import torch
import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.utils import (compute_max_diff, stack_and_dev, torch_moe,
torch_moe_single)
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe, single_marlin_moe)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
awq_marlin_quantize)
@ -81,7 +80,7 @@ def test_fused_marlin_moe_awq(
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids = fused_topk(a, score, topk, False)
marlin_output = fused_marlin_moe(
marlin_output = torch.ops.vllm.fused_marlin_moe(
a,
qweight1,
qweight2,
@ -150,14 +149,14 @@ def test_single_marlin_moe_multiply_awq(
score = torch.randn((m, e), device="cuda", dtype=dtype)
marlin_output = single_marlin_moe(a,
qweight,
scales,
score,
topk,
renormalize=False,
w_zeros=zp,
num_bits=num_bits)
marlin_output = torch.ops.vllm.single_marlin_moe(a,
qweight,
scales,
score,
topk,
renormalize=False,
w_zeros=zp,
num_bits=num_bits)
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)
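
These tests now call the kernels through the dispatcher as torch.ops.vllm.* rather than importing the Python functions directly; the bare `import vllm.model_executor.layers.fused_moe  # noqa` is what triggers the custom-op registration as an import side effect. A quick sanity check of that registration, assuming a vLLM build with Triton available:

import torch
import vllm.model_executor.layers.fused_moe  # noqa  # side effect: registers the vllm::* custom ops

# These lookups only resolve once registration has run.
print(torch.ops.vllm.fused_marlin_moe)
print(torch.ops.vllm.single_marlin_moe)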


@ -7,12 +7,11 @@ import torch
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev,
torch_moe, torch_moe_single)
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe, single_marlin_moe)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, moe_align_block_size)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
@ -193,7 +192,7 @@ def test_fused_marlin_moe(
topk,
renormalize=False,
)
marlin_output = fused_marlin_moe(
marlin_output = torch.ops.vllm.fused_marlin_moe(
a,
qweight1,
qweight2,
@ -309,7 +308,7 @@ def test_single_marlin_moe_multiply(
sort_indices = stack_and_dev(sort_indices_l)
score = torch.randn((m, e), device="cuda", dtype=dtype)
marlin_output = single_marlin_moe(
marlin_output = torch.ops.vllm.single_marlin_moe(
a,
qweight,
scales,


@ -1,23 +1,43 @@
from contextlib import contextmanager
from typing import Any, Dict, Optional
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.triton_utils import HAS_TRITON
_config: Optional[Dict[str, Any]] = None
@contextmanager
def override_config(config):
global _config
old_config = _config
_config = config
yield
_config = old_config
def get_config() -> Optional[Dict[str, Any]]:
return _config
__all__ = [
"FusedMoE",
"FusedMoEMethodBase",
"FusedMoeWeightScaleSupported",
"override_config",
"get_config",
]
if HAS_TRITON:
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe, single_marlin_moe)
# import to register the custom ops
import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa
import vllm.model_executor.layers.fused_moe.fused_moe # noqa
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts, fused_moe, fused_topk, get_config_file_name,
grouped_topk)
__all__ += [
"fused_marlin_moe",
"single_marlin_moe",
"fused_moe",
"fused_topk",
"fused_experts",


@ -1,6 +1,6 @@
"""Fused MoE utilities for GPTQ."""
import functools
from typing import Any, Dict, Optional
from typing import Optional
import torch
@ -18,6 +18,7 @@ def get_scalar_type(num_bits: int, has_zp: bool):
return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
@torch.library.custom_op("vllm::single_marlin_moe", mutates_args=[])
def single_marlin_moe(
hidden_states: torch.Tensor,
w: torch.Tensor,
@ -28,7 +29,6 @@ def single_marlin_moe(
g_idx: Optional[torch.Tensor] = None,
sort_indices: Optional[torch.Tensor] = None,
w_zeros: Optional[torch.Tensor] = None,
override_config: Optional[Dict[str, Any]] = None,
num_bits: int = 8,
is_k_full: bool = True,
) -> torch.Tensor:
@ -49,8 +49,6 @@ def single_marlin_moe(
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- num_bits (int): The number of bits in expert weights quantization.
Returns:
@ -79,7 +77,6 @@ def single_marlin_moe(
w.shape,
topk_ids.shape[1],
None,
override_config=override_config,
is_marlin=True)
config = get_config_func(M)
@ -122,6 +119,24 @@ def single_marlin_moe(
return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
@single_marlin_moe.register_fake
def _(
hidden_states: torch.Tensor,
w: torch.Tensor,
scales: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
g_idx: Optional[torch.Tensor] = None,
sort_indices: Optional[torch.Tensor] = None,
w_zeros: Optional[torch.Tensor] = None,
num_bits: int = 8,
is_k_full: bool = True,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
@torch.library.custom_op("vllm::fused_marlin_moe", mutates_args=[])
def fused_marlin_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
@ -137,7 +152,6 @@ def fused_marlin_moe(
sort_indices2: Optional[torch.Tensor] = None,
w1_zeros: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None,
override_config: Optional[Dict[str, Any]] = None,
num_bits: int = 8,
is_k_full: bool = True,
) -> torch.Tensor:
@ -161,8 +175,6 @@ def fused_marlin_moe(
permutation.
- topk_weights (torch.Tensor): Top-k weights.
- topk_ids (torch.Tensor): Indices of top-k elements.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
- w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
- num_bits (int): The number of bits in expert weights quantization.
@ -209,7 +221,6 @@ def fused_marlin_moe(
w2.shape,
topk_ids.shape[1],
None,
override_config=override_config,
is_marlin=True,
)
config = get_config_func(M)
@ -311,3 +322,25 @@ def fused_marlin_moe(
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1)
@fused_marlin_moe.register_fake
def _(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
gating_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
g_idx1: Optional[torch.Tensor] = None,
g_idx2: Optional[torch.Tensor] = None,
sort_indices1: Optional[torch.Tensor] = None,
sort_indices2: Optional[torch.Tensor] = None,
w1_zeros: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None,
num_bits: int = 8,
is_k_full: bool = True,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
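
Both Marlin MoE entry points are now registered with torch.library.custom_op and given register_fake (meta) implementations, so torch.compile can propagate shapes and dtypes through them without running the Marlin kernels, and call sites go through torch.ops.vllm.*. A self-contained toy of the same pattern in a made-up `toylib` namespace (not a vLLM op):

import torch

@torch.library.custom_op("toylib::scale", mutates_args=[])
def toy_scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    # stand-in for a real kernel launch
    return x * factor

@toy_scale.register_fake
def _(x: torch.Tensor, factor: float) -> torch.Tensor:
    # only shape/dtype/device matter while tracing
    return torch.empty_like(x)

def f(x: torch.Tensor) -> torch.Tensor:
    # dispatcher call, analogous to torch.ops.vllm.fused_marlin_moe above
    return torch.ops.toylib.scale(x, 2.0) + 1.0

compiled = torch.compile(f, fullgraph=True)  # no graph break at the custom op
print(compiled(torch.randn(4)))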


@ -358,9 +358,10 @@ def try_get_optimal_moe_config(
top_k: int,
dtype: Optional[str],
M: int,
override_config: Optional[Dict[str, Any]] = None,
is_marlin: bool = False,
):
from vllm.model_executor.layers.fused_moe import get_config
override_config = get_config()
if override_config:
config = override_config
else:
@ -465,19 +466,109 @@ def get_config_dtype_str(dtype: torch.dtype,
return None
@torch.library.custom_op("vllm::inplace_fused_experts",
mutates_args=["hidden_states"])
def inplace_fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
use_fp8_w8a8, use_int8_w8a16, w1_scale, w2_scale,
a1_scale, a2_scale)
@inplace_fused_experts.register_fake
def _(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None) -> None:
pass
@torch.library.custom_op("vllm::outplace_fused_experts", mutates_args=[])
def outplace_fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, use_fp8_w8a8, use_int8_w8a16, w1_scale,
w2_scale, a1_scale, a2_scale)
@outplace_fused_experts.register_fake
def _(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.empty_like(hidden_states)
def fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
override_config: Optional[Dict[str, Any]] = None,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None):
if inplace:
torch.ops.vllm.inplace_fused_experts(hidden_states, w1, w2,
topk_weights, topk_ids,
use_fp8_w8a8, use_int8_w8a16,
w1_scale, w2_scale, a1_scale,
a2_scale)
return hidden_states
else:
return torch.ops.vllm.outplace_fused_experts(
hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8,
use_int8_w8a16, w1_scale, w2_scale, a1_scale, a2_scale)
def fused_experts_impl(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None):
# Check constraints.
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
@ -504,7 +595,6 @@ def fused_experts(hidden_states: torch.Tensor,
w2.shape,
topk_ids.shape[1],
config_dtype,
override_config=override_config,
)
config = get_config_func(M)
@ -602,7 +692,6 @@ def fused_moe(
topk: int,
renormalize: bool,
inplace: bool = False,
override_config: Optional[Dict[str, Any]] = None,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
@ -628,8 +717,6 @@ def fused_moe(
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- num_expert_group: Optional[int]: additional parameter for grouped_topk
- topk_group: Optional[int]: additional parameter for grouped_topk
- use_grouped_topk: If True, use grouped_topk instead of fused_topk
@ -667,7 +754,6 @@ def fused_moe(
topk_weights,
topk_ids,
inplace=inplace,
override_config=override_config,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
w1_scale=w1_scale,
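
fused_experts becomes a thin Python wrapper because a custom op has to declare which arguments it mutates up front (`mutates_args`), so the in-place and out-of-place paths are registered as two separate ops and the wrapper only dispatches between them. A toy version of the same split, again in a made-up namespace:

import torch

@torch.library.custom_op("toylib::inplace_add_one", mutates_args=["x"])
def inplace_add_one(x: torch.Tensor) -> None:
    x.add_(1.0)

@inplace_add_one.register_fake
def _(x: torch.Tensor) -> None:
    pass

@torch.library.custom_op("toylib::outplace_add_one", mutates_args=[])
def outplace_add_one(x: torch.Tensor) -> torch.Tensor:
    return x + 1.0

@outplace_add_one.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(x)

def add_one(x: torch.Tensor, inplace: bool = False) -> torch.Tensor:
    # mirrors fused_experts(): pick the op, return the (possibly mutated) tensor
    if inplace:
        torch.ops.toylib.inplace_add_one(x)
        return x
    return torch.ops.toylib.outplace_add_one(x)

print(add_one(torch.zeros(3)))                # tensor([1., 1., 1.])
print(add_one(torch.zeros(3), inplace=True))  # same values, input mutated in place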


@ -12,7 +12,16 @@ from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
if current_platform.is_cuda_alike():
from .fused_moe import fused_experts
else:
fused_experts = None # type: ignore
if current_platform.is_tpu():
from .moe_pallas import fused_moe as fused_moe_pallas
else:
fused_moe_pallas = None # type: ignore
logger = init_logger(__name__)
@ -96,9 +105,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts)
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
@ -132,17 +138,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
assert not use_grouped_topk
assert num_expert_group is None
assert topk_group is None
assert custom_routing_function is None
return fused_moe(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk=top_k,
gating_output=router_logits,
renormalize=renormalize)
return fused_moe_pallas(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk=top_k,
gating_output=router_logits,
renormalize=renormalize)
forward_native = forward_cuda
class FusedMoE(torch.nn.Module):
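
The kernel entry points are now resolved once at module import, gated on the platform, rather than imported inside forward_cuda/forward_tpu; this removes function-local imports from the forward paths that torch.compile traces. The guard pattern, restated outside the diff:

from vllm.platforms import current_platform

# Resolve the per-platform kernel entry point a single time, at import.
if current_platform.is_cuda_alike():
    from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
else:
    fused_experts = None  # type: ignore  # never called on non-CUDA backends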


@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, List, Optional
import torch
from torch.nn import Parameter
import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import (
@ -435,10 +436,6 @@ class AWQMoEMethod(FusedMoEMethodBase):
topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe)
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
@ -449,7 +446,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function)
return fused_marlin_moe(
return torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_qweight,
layer.w2_qweight,


@ -6,6 +6,7 @@ import torch
from compressed_tensors import CompressionFormat
from compressed_tensors.quantization import QuantizationStrategy
import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
@ -481,10 +482,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe)
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
@ -495,7 +492,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function)
return fused_marlin_moe(
return torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_weight_packed,
layer.w2_weight_packed,


@ -2,6 +2,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Union
import torch
import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import (
@ -536,9 +537,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe)
# The input must currently be float16
orig_dtype = x.dtype
x = x.half()
@ -553,7 +551,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
num_expert_group=num_expert_group,
custom_routing_function=None)
return fused_marlin_moe(
return torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_qweight,
layer.w2_qweight,


@ -28,6 +28,7 @@ from torch import nn
from transformers.models.granitemoe import GraniteMoeConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
@ -244,6 +245,7 @@ class GraniteMoeDecoderLayer(nn.Module):
return hidden_states
@support_torch_compile
class GraniteMoeModel(nn.Module):
def __init__(
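
With the kernels exposed as compiler-visible custom ops, the model class only needs the @support_torch_compile decorator, applied here to GraniteMoeModel. As a rough analogy only (not vLLM's actual mechanism, which lives in vllm.compilation), the decorator plays the role of compiling the module's forward:

import torch
from torch import nn

class ToyModel(nn.Module):
    def __init__(self, dim: int = 32):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))

model = ToyModel()
# roughly what the decorator arranges for the decorated model class
model.forward = torch.compile(model.forward, fullgraph=True)
print(model(torch.randn(2, 32)).shape)  # torch.Size([2, 32])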