[Model] Adding support for MSFT Phi-3.5-MoE (#7729)

Co-authored-by: Your Name <you@example.com>
Co-authored-by: Zeqi Lin <zelin@microsoft.com>
Co-authored-by: Zeqi Lin <Zeqi.Lin@microsoft.com>
This commit is contained in:
Wenxiang 2024-08-31 03:42:57 +08:00 committed by GitHub
parent 2684efc467
commit 1248e8506a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 1254 additions and 81 deletions

View File

@ -147,6 +147,10 @@ Decoder-only Language Models
- Phi-3-Small - Phi-3-Small
- :code:`microsoft/Phi-3-small-8k-instruct`, :code:`microsoft/Phi-3-small-128k-instruct`, etc. - :code:`microsoft/Phi-3-small-8k-instruct`, :code:`microsoft/Phi-3-small-128k-instruct`, etc.
- -
* - :code:`PhiMoEForCausalLM`
- Phi-3.5-MoE
- :code:`microsoft/Phi-3.5-MoE-instruct`, etc.
-
* - :code:`PersimmonForCausalLM` * - :code:`PersimmonForCausalLM`
- Persimmon - Persimmon
- :code:`adept/persimmon-8b-base`, :code:`adept/persimmon-8b-chat`, etc. - :code:`adept/persimmon-8b-base`, :code:`adept/persimmon-8b-chat`, etc.

111
tests/models/test_phimoe.py Normal file
View File

@ -0,0 +1,111 @@
"""Compare the outputs of HF and vLLM for moe models using greedy sampling.
Run `pytest tests/models/test_phimoe.py`.
"""
import pytest
import torch
from vllm.utils import is_cpu
from .utils import check_logprobs_close
MODELS = [
"microsoft/Phi-3.5-MoE-instruct",
]
def test_phimoe_routing_function():
from vllm.model_executor.models.phimoe import phimoe_routing_function
test_case = {
0: {
"hidden_states":
torch.tensor([1, 2, 3, 4, 5, 6, 7, 8],
dtype=torch.float32,
requires_grad=False).view(4, 2),
"gating_output":
torch.tensor([0.1, 0.2, 0.3, 0.4],
dtype=torch.float32,
requires_grad=False),
"topk":
2,
"renormalize":
False,
},
1: {
"hidden_states":
torch.tensor([1, 2, 3, 4, 5, 6, 7, 8],
dtype=torch.float32,
requires_grad=False).view(4, 2),
"gating_output":
torch.tensor([0.4, 0.2, 0.3, 0.4],
dtype=torch.float32,
requires_grad=False),
"topk":
2,
"renormalize":
False,
}
}
ground_truth = {
0: {
"topk_weights":
torch.tensor([1., 1.], dtype=torch.float32, requires_grad=False),
"topk_ids":
torch.tensor([3, 2], dtype=torch.long, requires_grad=False),
},
1: {
"topk_weights":
torch.tensor([0.5, 1.], dtype=torch.float32, requires_grad=False),
"topk_ids":
torch.tensor([0, 3], dtype=torch.long, requires_grad=False),
}
}
for test_id in test_case:
topk_weights, topk_ids = phimoe_routing_function(**test_case[test_id])
assert torch.allclose(topk_weights,
ground_truth[test_id]["topk_weights"])
assert torch.equal(topk_ids, ground_truth[test_id]["topk_ids"])
def get_gpu_memory():
try:
props = torch.cuda.get_device_properties(torch.cuda.current_device())
gpu_memory = props.total_memory / (1024**3)
return gpu_memory
except Exception:
return 0
@pytest.mark.skipif(condition=is_cpu(),
reason="This test takes a lot time to run on CPU, "
"and vllm CI's disk space is not enough for this model.")
@pytest.mark.skipif(condition=get_gpu_memory() < 100,
reason="Skip this test if GPU memory is insufficient.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)
with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)

View File

@ -0,0 +1,130 @@
{
"3328": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"768": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"1792": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"2560": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"2816": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"3584": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"3840": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"1280": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"2304": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
}
}

View File

@ -0,0 +1,130 @@
{
"3840": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"1792": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"3584": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"2816": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"1280": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"768": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"3328": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"2560": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"2304": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
}
}

View File

@ -0,0 +1,130 @@
{
"2048": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"1792": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"3328": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"2560": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"768": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"2816": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"2304": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2
},
"1280": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"3840": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"3584": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
}
}

View File

@ -2,7 +2,7 @@
import functools import functools
import json import json
import os import os
from typing import Any, Dict, Optional, Tuple from typing import Any, Callable, Dict, Optional, Tuple
import torch import torch
import triton import triton
@ -446,7 +446,8 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
rand_perm1: torch.Tensor, rand_perm1: torch.Tensor,
rand_perm2: torch.Tensor, rand_perm2: torch.Tensor,
topk: int, topk: int,
renormalize: bool, custom_routing_function: Optional[Callable] = None,
renormalize: bool = True,
override_config: Optional[Dict[str, Any]] = None, override_config: Optional[Dict[str, Any]] = None,
use_fp8: bool = False, use_fp8: bool = False,
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
@ -497,8 +498,12 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
E = w1.shape[0] E = w1.shape[0]
N = w2.shape[1] * 16 N = w2.shape[1] * 16
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, if custom_routing_function is None:
renormalize) topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize)
else:
topk_weights, topk_ids = custom_routing_function(
hidden_states, gating_output, topk, renormalize)
get_config_func = functools.partial(try_get_optimal_moe_config, get_config_func = functools.partial(try_get_optimal_moe_config,
w1.shape, w1.shape,
@ -695,6 +700,7 @@ def fused_moe(
use_grouped_topk: bool = False, use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None, num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
@ -742,9 +748,12 @@ def fused_moe(
topk_weights, topk_ids = grouped_topk(hidden_states, gating_output, topk_weights, topk_ids = grouped_topk(hidden_states, gating_output,
topk, renormalize, topk, renormalize,
num_expert_group, topk_group) num_expert_group, topk_group)
else: elif custom_routing_function is None:
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize) renormalize)
else:
topk_weights, topk_ids = custom_routing_function(
hidden_states, gating_output, topk, renormalize)
return fused_experts(hidden_states, return fused_experts(hidden_states,
w1, w1,

View File

@ -1,6 +1,6 @@
from abc import abstractmethod from abc import abstractmethod
from enum import Enum from enum import Enum
from typing import List, Optional, Tuple from typing import Callable, List, Optional, Tuple
import torch import torch
@ -62,15 +62,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer.register_parameter("w2_weight", w2_weight) layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs) set_weight_attrs(w2_weight, extra_weight_attrs)
def apply(self, def apply(
layer: torch.nn.Module, self,
x: torch.Tensor, layer: torch.nn.Module,
router_logits: torch.Tensor, x: torch.Tensor,
top_k: int, router_logits: torch.Tensor,
renormalize: bool, top_k: int,
use_grouped_topk: bool, renormalize: bool,
topk_group: Optional[int] = None, use_grouped_topk: bool,
num_expert_group: Optional[int] = None) -> torch.Tensor: topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None
) -> torch.Tensor:
return self.forward(x=x, return self.forward(x=x,
layer=layer, layer=layer,
@ -79,17 +82,21 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
renormalize=renormalize, renormalize=renormalize,
use_grouped_topk=use_grouped_topk, use_grouped_topk=use_grouped_topk,
topk_group=topk_group, topk_group=topk_group,
num_expert_group=num_expert_group) num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function)
def forward_cuda(self, def forward_cuda(
layer: torch.nn.Module, self,
x: torch.Tensor, layer: torch.nn.Module,
use_grouped_topk: bool, x: torch.Tensor,
top_k: int, use_grouped_topk: bool,
router_logits: torch.Tensor, top_k: int,
renormalize: bool, router_logits: torch.Tensor,
topk_group: Optional[int] = None, renormalize: bool,
num_expert_group: Optional[int] = None) -> torch.Tensor: topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts) fused_experts)
@ -101,7 +108,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
top_k=top_k, top_k=top_k,
renormalize=renormalize, renormalize=renormalize,
topk_group=topk_group, topk_group=topk_group,
num_expert_group=num_expert_group) num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function)
return fused_experts(hidden_states=x, return fused_experts(hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
@ -114,20 +122,24 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
raise NotImplementedError( raise NotImplementedError(
"The CPU backend currently does not support MoE.") "The CPU backend currently does not support MoE.")
def forward_tpu(self, def forward_tpu(
layer: torch.nn.Module, self,
x: torch.Tensor, layer: torch.nn.Module,
use_grouped_topk: bool, x: torch.Tensor,
top_k: int, use_grouped_topk: bool,
router_logits: torch.Tensor, top_k: int,
renormalize: bool, router_logits: torch.Tensor,
topk_group: Optional[int] = None, renormalize: bool,
num_expert_group: Optional[int] = None) -> torch.Tensor: topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
assert not use_grouped_topk assert not use_grouped_topk
assert num_expert_group is None assert num_expert_group is None
assert topk_group is None assert topk_group is None
assert custom_routing_function is None
return fused_moe(hidden_states=x, return fused_moe(hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
w2=layer.w2_weight, w2=layer.w2_weight,
@ -172,6 +184,7 @@ class FusedMoE(torch.nn.Module):
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None, tp_size: Optional[int] = None,
prefix: str = "", prefix: str = "",
custom_routing_function: Optional[Callable] = None,
): ):
super().__init__() super().__init__()
@ -190,6 +203,7 @@ class FusedMoE(torch.nn.Module):
assert num_expert_group is not None and topk_group is not None assert num_expert_group is not None and topk_group is not None
self.num_expert_group = num_expert_group self.num_expert_group = num_expert_group
self.topk_group = topk_group self.topk_group = topk_group
self.custom_routing_function = custom_routing_function
if quant_config is None: if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = ( self.quant_method: Optional[QuantizeMethodBase] = (
@ -390,7 +404,8 @@ class FusedMoE(torch.nn.Module):
use_grouped_topk: bool, use_grouped_topk: bool,
renormalize: bool, renormalize: bool,
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None): num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None):
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, grouped_topk) fused_topk, grouped_topk)
@ -405,11 +420,17 @@ class FusedMoE(torch.nn.Module):
renormalize=renormalize, renormalize=renormalize,
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
topk_group=topk_group) topk_group=topk_group)
else: elif custom_routing_function is None:
topk_weights, topk_ids = fused_topk(hidden_states=hidden_states, topk_weights, topk_ids = fused_topk(hidden_states=hidden_states,
gating_output=router_logits, gating_output=router_logits,
topk=top_k, topk=top_k,
renormalize=renormalize) renormalize=renormalize)
else:
topk_weights, topk_ids = custom_routing_function(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize)
return topk_weights, topk_ids return topk_weights, topk_ids
@ -426,7 +447,8 @@ class FusedMoE(torch.nn.Module):
renormalize=self.renormalize, renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk, use_grouped_topk=self.use_grouped_topk,
topk_group=self.topk_group, topk_group=self.topk_group,
num_expert_group=self.num_expert_group) num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function)
if self.reduce_results and self.tp_size > 1: if self.reduce_results and self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states = tensor_model_parallel_all_reduce(

View File

@ -1,6 +1,6 @@
import enum import enum
from enum import Enum from enum import Enum
from typing import List, Optional from typing import Callable, List, Optional
import torch import torch
@ -256,15 +256,18 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
) )
replace_tensor("w2_weight_scale", marlin_w2_scales) replace_tensor("w2_weight_scale", marlin_w2_scales)
def apply(self, def apply(
layer: torch.nn.Module, self,
x: torch.Tensor, layer: torch.nn.Module,
router_logits: torch.Tensor, x: torch.Tensor,
top_k: int, router_logits: torch.Tensor,
renormalize: bool = True, top_k: int,
use_grouped_topk: bool = False, renormalize: bool = True,
num_expert_group: Optional[int] = None, use_grouped_topk: bool = False,
topk_group: Optional[int] = None) -> torch.Tensor: num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_marlin_moe) fused_marlin_moe)
@ -278,6 +281,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
layer.w13_g_idx_sort_indices, layer.w13_g_idx_sort_indices,
layer.w2_g_idx_sort_indices, layer.w2_g_idx_sort_indices,
top_k, top_k,
custom_routing_function,
renormalize=renormalize, renormalize=renormalize,
w1_scale=layer.w13_weight_scale, w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale) w2_scale=layer.w2_weight_scale)

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
import torch import torch
@ -96,15 +96,18 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
requires_grad=False) requires_grad=False)
layer.register_parameter("w2_scale", w2_scale) layer.register_parameter("w2_scale", w2_scale)
def apply(self, def apply(
layer: torch.nn.Module, self,
x: torch.Tensor, layer: torch.nn.Module,
router_logits: torch.Tensor, x: torch.Tensor,
top_k: int, router_logits: torch.Tensor,
renormalize: bool = True, top_k: int,
use_grouped_topk: bool = False, renormalize: bool = True,
num_expert_group: Optional[int] = None, use_grouped_topk: bool = False,
topk_group: Optional[int] = None) -> torch.Tensor: num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
@ -114,7 +117,8 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
top_k=top_k, top_k=top_k,
renormalize=renormalize, renormalize=renormalize,
topk_group=topk_group, topk_group=topk_group,
num_expert_group=num_expert_group) num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function)
return fused_experts(x, return fused_experts(x,
layer.w13_weight, layer.w13_weight,

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
import torch import torch
from torch.nn import Module from torch.nn import Module
@ -468,15 +468,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
requires_grad=False) requires_grad=False)
return return
def apply(self, def apply(
layer: torch.nn.Module, self,
x: torch.Tensor, layer: torch.nn.Module,
router_logits: torch.Tensor, x: torch.Tensor,
top_k: int, router_logits: torch.Tensor,
renormalize: bool, top_k: int,
use_grouped_topk: bool, renormalize: bool,
topk_group: Optional[int] = None, use_grouped_topk: bool,
num_expert_group: Optional[int] = None) -> torch.Tensor: topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
@ -487,7 +490,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
top_k=top_k, top_k=top_k,
renormalize=renormalize, renormalize=renormalize,
topk_group=topk_group, topk_group=topk_group,
num_expert_group=num_expert_group) num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function)
return fused_experts(x, return fused_experts(x,
layer.w13_weight, layer.w13_weight,

View File

@ -503,8 +503,8 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
dtype: torch.dtype, dtype: torch.dtype,
short_factor: List[float], short_factor: List[float],
long_factor: List[float], long_factor: List[float],
short_mscale: float = 1.0, short_mscale: Optional[float] = None,
long_mscale: float = 1.0, long_mscale: Optional[float] = None,
): ):
super().__init__() super().__init__()
@ -523,18 +523,22 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
self.base = base self.base = base
self.short_factor = short_factor self.short_factor = short_factor
self.long_factor = long_factor self.long_factor = long_factor
self.short_mscale = short_mscale
self.long_mscale = long_mscale
scale = (self.max_position_embeddings /
self.original_max_position_embeddings)
scale = self.max_position_embeddings / \
self.original_max_position_embeddings
if scale <= 1.0: if scale <= 1.0:
self.scaling_factor = 1.0 scaling_factor = 1.0
else: else:
self.scaling_factor = math.sqrt( scaling_factor = math.sqrt(
1 + math.log(scale) / 1 + math.log(scale) /
math.log(self.original_max_position_embeddings)) math.log(self.original_max_position_embeddings))
if short_mscale is None:
short_mscale = scaling_factor
if long_mscale is None:
long_mscale = scaling_factor
self.short_mscale = short_mscale
self.long_mscale = long_mscale
short_cache = self._compute_cos_sin_cache( short_cache = self._compute_cos_sin_cache(
original_max_position_embeddings, short_factor, short_mscale) original_max_position_embeddings, short_factor, short_mscale)
@ -571,8 +575,8 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
inv_freq = self._compute_inv_freq(rescale_factors) inv_freq = self._compute_inv_freq(rescale_factors)
t = torch.arange(max_position_embeddings, dtype=torch.float) t = torch.arange(max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq) freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos() * mscale * self.scaling_factor cos = freqs.cos() * mscale
sin = freqs.sin() * mscale * self.scaling_factor sin = freqs.sin() * mscale
cache = torch.cat((cos, sin), dim=-1) cache = torch.cat((cos, sin), dim=-1)
return cache return cache

View File

@ -50,6 +50,7 @@ _GENERATION_MODELS = {
"PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"), "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
"PhiForCausalLM": ("phi", "PhiForCausalLM"), "PhiForCausalLM": ("phi", "PhiForCausalLM"),
"Phi3ForCausalLM": ("llama", "LlamaForCausalLM"), "Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
"PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),

View File

@ -0,0 +1,620 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only PhiMoE model."""
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
from .interfaces import SupportsLoRA
class PhiMoEConfig(PretrainedConfig):
model_type = "phimoe"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=14336,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=8,
hidden_act="silu",
max_position_embeddings=4096 * 32,
initializer_range=0.02,
rms_norm_eps=1e-5,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
rope_theta=1e6,
sliding_window=None,
attention_dropout=0.0,
num_experts_per_tok=2,
num_local_experts=16,
output_router_logits=False,
router_aux_loss_coef=0.001,
router_jitter_noise=0.0,
attention_bias=False,
lm_head_bias=False,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.sliding_window = sliding_window
self.attention_bias = attention_bias
self.lm_head_bias = lm_head_bias
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_dropout = attention_dropout
self.num_experts_per_tok = num_experts_per_tok
self.num_local_experts = num_local_experts
self.output_router_logits = output_router_logits
self.router_aux_loss_coef = router_aux_loss_coef
self.router_jitter_noise = router_jitter_noise
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
class mp(torch.autograd.Function):
@staticmethod
def forward(
ctx,
scores: torch.Tensor,
multiplier: torch.Tensor,
selected_experts: torch.Tensor,
masked_gates: torch.Tensor,
mask_for_one: torch.Tensor,
):
ctx.save_for_backward(multiplier, selected_experts, masked_gates)
return multiplier * mask_for_one
@staticmethod
def backward(
ctx,
grad_at_output: torch.Tensor,
):
multiplier, selected_experts, masked_gates = ctx.saved_tensors
grad_at_output = grad_at_output * multiplier
grad_at_scores_expaned = masked_gates * grad_at_output.mul(-1)
grad_at_scores_expaned.scatter_add_(
dim=-1,
index=selected_experts,
src=grad_at_output,
)
return (
grad_at_scores_expaned,
None,
None,
None,
None,
)
def sparsemixer(scores, jitter_eps=0.01):
################ first expert ################
with torch.no_grad():
# compute mask for sparsity
mask_logits_threshold, max_ind = scores.max(dim=-1, keepdim=True)
factor = scores.abs().clamp(min=mask_logits_threshold)
mask_logits_threshold = (
(mask_logits_threshold - scores) / factor) > (2 * jitter_eps)
# apply mask
masked_gates = scores.masked_fill(mask_logits_threshold, float("-inf"))
selected_experts = max_ind
# compute scores for gradients
masked_gates = torch.softmax(masked_gates, dim=-1)
multiplier_o = masked_gates.gather(dim=-1, index=selected_experts)
multiplier = multiplier_o
# masked out first expert
masked_scores = torch.scatter(
scores,
-1,
selected_experts,
float("-inf"),
)
with torch.no_grad():
# compute mask for sparsity
mask_logits_threshold, max_ind = masked_scores.max(dim=-1,
keepdim=True)
factor = scores.abs().clamp(min=mask_logits_threshold)
mask_logits_threshold = (
(mask_logits_threshold - scores) / factor) > (2 * jitter_eps)
# apply mask
masked_gates_top2 = masked_scores.masked_fill(mask_logits_threshold,
float("-inf"))
selected_experts_top2 = max_ind
# compute scores for gradients
masked_gates_top2 = torch.softmax(masked_gates_top2, dim=-1)
multiplier_top2 = masked_gates_top2.gather(dim=-1,
index=selected_experts_top2)
multiplier = torch.concat((multiplier, multiplier_top2), dim=-1)
selected_experts = torch.concat((selected_experts, selected_experts_top2),
dim=-1)
return (
multiplier,
selected_experts,
)
def phimoe_routing_function(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
):
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
assert topk == 2, "Only top-2 routing is supported"
assert renormalize is False, "Renormalization is not supported"
topk_weights, topk_ids = sparsemixer(gating_output)
return topk_weights, topk_ids
class PhiMoE(nn.Module):
"""A tensor-parallel MoE implementation for PhiMoE that shards each expert
across all ranks.
Each expert's weights are sharded across all ranks and a fused MoE
kernel is used for the forward pass, and finally we reduce the outputs
across ranks.
"""
def __init__(
self,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
):
super().__init__()
self.hidden_size = hidden_size
# Gate always runs at half / full precision for now.
self.gate = ReplicatedLinear(
hidden_size,
num_experts,
bias=False,
params_dtype=params_dtype,
quant_config=None,
)
self.experts = FusedMoE(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=params_dtype,
reduce_results=True,
renormalize=False,
quant_config=quant_config,
tp_size=tp_size,
custom_routing_function=phimoe_routing_function)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(hidden_states, router_logits)
return final_hidden_states.view(orig_shape)
class PhiMoEAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
rope_scaling: Optional[dict] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=True,
quant_config=None,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=True,
quant_config=None,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
base=int(self.rope_theta),
is_neox_style=True,
rope_scaling=self.rope_scaling,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
class PhiMoEDecoderLayer(nn.Module):
def __init__(
self,
config: PhiMoEConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 10000)
self.self_attn = PhiMoEAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
cache_config=cache_config,
quant_config=quant_config,
rope_scaling=config.rope_scaling,
)
self.block_sparse_moe = PhiMoE(
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
quant_config=quant_config,
)
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.rms_norm_eps,
elementwise_affine=True)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.rms_norm_eps,
elementwise_affine=True)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
residual = hidden_states
# Self Attention
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = hidden_states + residual
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.block_sparse_moe(hidden_states)
hidden_states = hidden_states + residual
return hidden_states, residual
class PhiMoEModel(nn.Module):
def __init__(
self,
config: PhiMoEConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
self.padding_idx = config.pad_token_id
lora_vocab = ((lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0)
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
self.layers = nn.ModuleList([
PhiMoEDecoderLayer(config, cache_config, quant_config=quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = nn.LayerNorm(config.hidden_size,
eps=config.rms_norm_eps,
elementwise_affine=True)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i], attn_metadata,
residual)
hidden_states = self.norm(hidden_states)
return hidden_states
class PhiMoEForCausalLM(nn.Module, SupportsLoRA):
fall_back_to_pt_during_load = False
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"embed_tokens",
"lm_head",
]
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(
self,
config: PhiMoEConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
self.config = config
self.lora_config = lora_config
self.model = PhiMoEModel(config,
cache_config,
quant_config,
lora_config=lora_config)
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=(
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size),
quant_config=None,
bias=True,
)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="w1",
ckpt_down_proj_name="w2",
ckpt_up_proj_name="w3",
num_experts=self.config.num_local_experts)
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
weight_name,
shard_id=shard_id,
expert_id=expert_id,
)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)