Modularize fused experts and integrate PPLX kernels (#15956)
parent 418d2f8bfb, commit f9c069c85e
@@ -70,6 +70,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
  int64_t num_tokens = input.numel() / input.size(-1);              \
  dim3 grid(num_tokens);                                            \
  dim3 block(std::min(d, 1024));                                    \
  if (num_tokens == 0) {                                            \
    return;                                                         \
  }                                                                 \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();     \
  VLLM_DISPATCH_FLOATING_TYPES(                                     \
@@ -65,5 +65,19 @@
  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)   \
  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)

#define VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(...) \
  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)      \
  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)      \
  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)     \
  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)       \
  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)      \
  AT_DISPATCH_CASE(at::ScalarType::UInt16, __VA_ARGS__)    \
  AT_DISPATCH_CASE(at::ScalarType::UInt32, __VA_ARGS__)    \
  AT_DISPATCH_CASE(at::ScalarType::UInt64, __VA_ARGS__)

#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                              \
      TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))
@@ -326,7 +326,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
  }

  if (use_global_memory) {
    VLLM_DISPATCH_INTEGRAL_TYPES(
    VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
        topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
          // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
          // tensors
@@ -351,7 +351,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
            cumsum_buffer.data_ptr<int32_t>());
        });
  } else if (use_i16) {
    VLLM_DISPATCH_INTEGRAL_TYPES(
    VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
        topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
          // set dynamic shared mem
          auto kernel =
@@ -366,7 +366,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
              topk_ids.numel());
        });
  } else {
    VLLM_DISPATCH_INTEGRAL_TYPES(
    VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
        topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
          auto kernel =
              vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
@@ -391,7 +391,7 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
  TORCH_CHECK(num_experts == 256,
              "sgl_moe_align_block_size kernel only supports deepseek v3.");

  VLLM_DISPATCH_INTEGRAL_TYPES(
  VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
      topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] {
        // calc needed amount of shared mem for `cumsum` tensors
        auto options_int =
@@ -108,9 +108,17 @@ __launch_bounds__(TPB) __global__
    }
}

template <int TPB>
__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output,
    int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert)
template <int TPB, typename IndType>
__launch_bounds__(TPB) __global__ void moeTopK(
    const float* inputs_after_softmax,
    const bool* finished,
    float* output,
    IndType* indices,
    int* source_rows,
    const int num_experts,
    const int k,
    const int start_expert,
    const int end_expert)
{

    using cub_kvp = cub::KeyValuePair<int, float>;
@@ -182,9 +190,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax
  2) This implementation assumes k is small, but will work for any k.
*/

template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, typename IndType>
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
    void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices,
    void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, IndType* indices,
        int* source_rows, const int k, const int start_expert, const int end_expert)
{
    // We begin by enforcing compile time assertions and setting up compile time constants.
@@ -397,8 +405,8 @@ struct TopkConstants
};
} // namespace detail

template <int EXPERTS, int WARPS_PER_TB>
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices,
template <int EXPERTS, int WARPS_PER_TB, typename IndType>
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices,
    int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
{
    static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
@@ -421,10 +429,11 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
        token_expert_indices, num_tokens, topk, 0, num_experts, \
        stream);

template <typename IndType>
void topkGatingSoftmaxKernelLauncher(
    const float* gating_output,
    float* topk_weights,
    int* topk_indicies,
    IndType* topk_indicies,
    int* token_expert_indices,
    float* softmax_workspace,
    const int num_tokens,
@@ -493,6 +502,9 @@ void topk_softmax(
    const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());

    if(topk_indices.scalar_type() == at::ScalarType::Int)
    {
        vllm::moe::topkGatingSoftmaxKernelLauncher(
            gating_output.data_ptr<float>(),
            topk_weights.data_ptr<float>(),
@@ -504,3 +516,18 @@ void topk_softmax(
            topk,
            stream);
    }
    else
    {
        assert(topk_indices.scalar_type() == at::ScalarType::UInt32);
        vllm::moe::topkGatingSoftmaxKernelLauncher(
            gating_output.data_ptr<float>(),
            topk_weights.data_ptr<float>(),
            topk_indices.data_ptr<uint32_t>(),
            token_expert_indices.data_ptr<int>(),
            softmax_workspace.data_ptr<float>(),
            num_tokens,
            num_experts,
            topk,
            stream);
    }
}
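With the templated IndType above, the host entry point dispatches on the dtype of the index tensor, so either int32 or uint32 indices can be passed from Python. A minimal sketch of exercising both paths, assuming the op is exposed as vllm._custom_ops.topk_softmax with the argument order used below (that binding and signature are an assumption and should be checked against the actual Python bindings):

# Hedged sketch: topk_softmax with int32 vs. uint32 index tensors.
# The ops.topk_softmax binding/signature is assumed, not taken from this diff.
import torch
from vllm import _custom_ops as ops

num_tokens, num_experts, topk = 16, 8, 2
gating_output = torch.randn(num_tokens, num_experts, device="cuda", dtype=torch.float32)

for idx_dtype in (torch.int32, torch.uint32):  # both branches are handled above
    topk_weights = torch.empty(num_tokens, topk, device="cuda", dtype=torch.float32)
    topk_indices = torch.empty(num_tokens, topk, device="cuda", dtype=idx_dtype)
    token_expert_indices = torch.empty(num_tokens, topk, device="cuda", dtype=torch.int32)
    ops.topk_softmax(topk_weights, topk_indices, token_expert_indices, gating_output)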
@@ -65,11 +65,17 @@ def parse_args():
                        type=int,
                        default=0,
                        help="Master node port")
    parser.add_argument("--enforce-eager",
                        action='store_true',
                        help="Enforce eager mode execution.")
    parser.add_argument("--trust-remote-code",
                        action='store_true',
                        help="Trust remote code.")
    return parser.parse_args()


def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
         dp_master_port, GPUs_per_dp_rank):
         dp_master_port, GPUs_per_dp_rank, enforce_eager, trust_remote_code):
    os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
    os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
    os.environ["VLLM_DP_SIZE"] = str(dp_size)
@@ -109,10 +115,13 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
                                     max_tokens=[16, 20][global_dp_rank % 2])

    # Create an LLM.
    llm = LLM(model=model,
    llm = LLM(
        model=model,
        tensor_parallel_size=GPUs_per_dp_rank,
              enforce_eager=True,
              enable_expert_parallel=True)
        enforce_eager=enforce_eager,
        enable_expert_parallel=True,
        trust_remote_code=trust_remote_code,
    )
    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    for i, output in enumerate(outputs):
@@ -155,7 +164,8 @@ if __name__ == "__main__":
        proc = Process(target=main,
                       args=(args.model, dp_size, local_dp_rank,
                             global_dp_rank, dp_master_ip, dp_master_port,
                             tp_size))
                             tp_size, args.enforce_eager,
                             args.trust_remote_code))
        proc.start()
        procs.append(proc)
    exit_code = 0
tests/kernels/moe/test_batched_moe.py (new file, 114 lines)
@@ -0,0 +1,114 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass

import pytest
import torch
import triton.language as tl

from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    invoke_moe_batched_triton_kernel)


@dataclass
class BatchedMMConfig:
    dtype: torch.dtype
    num_experts: int
    max_tokens_per_expert: int
    K: int
    N: int


@dataclass
class BatchedMMTensors:
    A: torch.Tensor  # [E, max_tokens, K]
    B: torch.Tensor  # [E, K, N] - column major
    C: torch.Tensor  # [E, max_tokens, N]
    num_expert_tokens: torch.Tensor  # [E]

    @staticmethod
    def make_tensors(config: BatchedMMConfig):
        A = torch.randn(
            (config.num_experts, config.max_tokens_per_expert, config.K),
            device="cuda",
            dtype=config.dtype) / 10
        B = torch.randn((config.num_experts, config.N, config.K),
                        device="cuda",
                        dtype=config.dtype)
        C = torch.zeros(
            (config.num_experts, config.max_tokens_per_expert, config.N),
            device="cuda",
            dtype=config.dtype)
        num_expert_tokens = torch.randint(low=0,
                                          high=config.max_tokens_per_expert,
                                          size=(config.num_experts, ),
                                          device="cuda",
                                          dtype=torch.int32)
        return BatchedMMTensors(A, B, C, num_expert_tokens)


def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
             num_expert_tokens: torch.Tensor) -> torch.Tensor:

    num_expert_tokens_cpu = num_expert_tokens.clone()
    num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
    num_experts = num_expert_tokens.size(0)

    for e in range(num_experts):
        num_tokens = num_expert_tokens_cpu[e]
        C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)

    return C


@pytest.mark.parametrize("num_experts", [16, 32])
@pytest.mark.parametrize("max_tokens_per_expert",
                         [32, 64, 128, 192, 224, 256, 512])
@pytest.mark.parametrize("K", [128, 256, 1024])
@pytest.mark.parametrize("N", [128, 256, 512, 1024])
@pytest.mark.parametrize("dtype",
                         [torch.float32, torch.float16, torch.bfloat16])
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
                    N: int, dtype: torch.dtype):

    config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N)
    tensors = BatchedMMTensors.make_tensors(config)

    test_output = tensors.C
    ref_output = test_output.clone()

    compute_tl_dtype = {
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
        torch.float32: tl.float32
    }[test_output.dtype]
    invoke_moe_batched_triton_kernel(
        tensors.A,
        tensors.B,
        test_output,
        tensors.num_expert_tokens,
        compute_tl_dtype,
        # Quantization data
        None,
        None,
        None,
        # Quantization schemes
        False,
        False,
        False,
        config={
            "BLOCK_SIZE_M": 16,
            "BLOCK_SIZE_N": 16,
            "BLOCK_SIZE_K": 16
        })

    ref_output = ref_impl(tensors.A, tensors.B, ref_output,
                          tensors.num_expert_tokens)

    rtol, atol = {
        torch.float16: (6e-2, 6e-2),
        torch.bfloat16: (6e-2, 6e-2),
        torch.float32: (1e-2, 1e-2),
    }[test_output.dtype]

    torch.testing.assert_close(test_output, ref_output, atol=atol, rtol=rtol)
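For context, the batched format exercised by the test above pads each expert's activations to max_tokens_per_expert rows and records the number of valid rows per expert in num_expert_tokens. A small self-contained sketch of that layout in plain PyTorch (illustrative only, no Triton involved):

# Illustrative only: the padded [E, max_tokens, K] layout consumed by
# ref_impl() and the batched Triton kernel tested above.
import torch

E, max_tokens, K, N = 2, 4, 8, 6
A = torch.randn(E, max_tokens, K)         # padded per-expert activations
B = torch.randn(E, N, K)                  # per-expert weights, [E, N, K]
num_expert_tokens = torch.tensor([3, 1])  # valid rows per expert

C = torch.zeros(E, max_tokens, N)
for e in range(E):
    n = int(num_expert_tokens[e])
    # rows >= n are padding and are simply left untouched (zeros here)
    C[e, :n] = A[e, :n] @ B[e].T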
@ -30,6 +30,11 @@ MNK_FACTORS = [
|
||||
(224, 3072, 1536),
|
||||
]
|
||||
|
||||
vllm_config = VllmConfig(parallel_config=ParallelConfig(
|
||||
pipeline_parallel_size=1))
|
||||
vllm_config.scheduler_config.max_num_seqs = 128
|
||||
vllm_config.scheduler_config.max_model_len = 8192
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MOETensors:
|
||||
@ -190,7 +195,7 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
|
||||
'w1_q': moe_tensors.w1_q.transpose(1, 2), # type: ignore[union-attr]
|
||||
'w2_q': moe_tensors.w2_q.transpose(1, 2), # type: ignore[union-attr]
|
||||
'topk_weights': topk_weights,
|
||||
'topk_ids_': topk_ids,
|
||||
'topk_ids': topk_ids,
|
||||
'ab_strides1': moe_tensors.ab_strides1,
|
||||
'c_strides1': moe_tensors.c_strides1,
|
||||
'ab_strides2': moe_tensors.ab_strides2,
|
||||
@ -231,15 +236,12 @@ def test_cutlass_moe_8_bit_no_graph(
|
||||
per_out_ch: bool,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
with set_current_vllm_config(
|
||||
VllmConfig(parallel_config=ParallelConfig(
|
||||
pipeline_parallel_size=1))):
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
|
||||
per_out_ch)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=torch.half)
|
||||
        topk_weights, topk_ids = fused_topk(mt.a,
        topk_weights, topk_ids, _ = fused_topk(mt.a,
                                               score,
                                               topk,
                                               renormalize=False)
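As this hunk and the other call sites in the diff show, fused_topk now returns a third value, and two-value unpacking at old call sites breaks. A hedged sketch of the updated call pattern (what the third value contains is not shown in this diff, so it is simply discarded here; check fused_moe.fused_topk for the authoritative signature):

# Hedged sketch: unpack (or discard) the new third return value of fused_topk.
import torch
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk

hidden_states = torch.randn(4, 128, device="cuda", dtype=torch.half)
gating_output = torch.randn(4, 8, device="cuda", dtype=torch.half)

# Old code: topk_weights, topk_ids = fused_topk(...)  -> now raises ValueError
topk_weights, topk_ids, _ = fused_topk(hidden_states, gating_output, 2, False)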
@ -276,17 +278,14 @@ def test_cutlass_moe_8_bit_cuda_graph(
|
||||
per_out_ch: bool,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
with set_current_vllm_config(
|
||||
VllmConfig(parallel_config=ParallelConfig(
|
||||
pipeline_parallel_size=1))):
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
dtype = torch.half
|
||||
|
||||
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
|
||||
per_out_ch)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
topk_weights, topk_ids = fused_topk(mt.a,
|
||||
topk_weights, topk_ids, _ = fused_topk(mt.a,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False)
|
||||
@ -334,15 +333,12 @@ def test_cutlass_moe_8_bit_EP(
|
||||
ep_size: int,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
with set_current_vllm_config(
|
||||
VllmConfig(parallel_config=ParallelConfig(
|
||||
pipeline_parallel_size=1))):
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
|
||||
per_out_channel)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=torch.half)
|
||||
topk_weights, topk_ids = fused_topk(mt.a,
|
||||
topk_weights, topk_ids, _ = fused_topk(mt.a,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False)
|
||||
|
||||
@ -12,6 +12,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||
|
||||
import vllm.model_executor.layers.fused_moe # noqa
|
||||
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
|
||||
@ -32,6 +33,10 @@ NUM_EXPERTS = [8, 64]
|
||||
EP_SIZE = [1, 4]
|
||||
TOP_KS = [2, 6]
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.scheduler_config.max_num_seqs = 128
|
||||
vllm_config.scheduler_config.max_model_len = 8192
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128])
|
||||
@pytest.mark.parametrize("n", [128, 1024, 2048])
|
||||
@ -70,6 +75,7 @@ def test_fused_moe(
|
||||
else:
|
||||
e_map = None
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
torch_output = torch_moe(a, w1, w2, score, topk, e_map)
|
||||
iterative_output = iterative_moe(a,
|
||||
w1,
|
||||
@ -95,6 +101,7 @@ def test_fused_moe(
|
||||
global_num_experts=e,
|
||||
expert_map=e_map,
|
||||
renormalize=False)
|
||||
|
||||
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
|
||||
torch.testing.assert_close(iterative_output,
|
||||
torch_output,
|
||||
@ -115,7 +122,6 @@ def test_fused_moe(
|
||||
def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
|
||||
ep_size: int, dtype: torch.dtype, group_size: int,
|
||||
has_zp: bool, weight_bits: int):
|
||||
print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits)
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||
@ -194,6 +200,7 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
|
||||
else:
|
||||
e_map = None
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
triton_output = fused_moe(a,
|
||||
w1_qweight,
|
||||
w2_qweight,
|
||||
@ -210,6 +217,7 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
|
||||
w2_zp=w2_qzeros if has_zp else None,
|
||||
block_shape=[0, group_size])
|
||||
torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map)
|
||||
|
||||
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
|
||||
|
||||
|
||||
@ -515,6 +523,7 @@ def test_fused_marlin_moe(
|
||||
|
||||
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map)
|
||||
|
||||
marlin_output = torch.ops.vllm.fused_marlin_moe(
|
||||
|
||||
tests/kernels/moe/test_pplx_moe.py (new file, 691 lines)
@@ -0,0 +1,691 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Tests for the MOE layers.
|
||||
|
||||
Run `pytest tests/kernels/test_pplx_moe.py`.
|
||||
"""
|
||||
import dataclasses
|
||||
import os
|
||||
import traceback
|
||||
from typing import Callable, Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
try:
|
||||
from pplx_kernels import AllToAll
|
||||
from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
|
||||
nvshmem_finalize, nvshmem_get_unique_id,
|
||||
nvshmem_init)
|
||||
has_pplx = True
|
||||
except ImportError:
|
||||
has_pplx = False
|
||||
|
||||
from torch.multiprocessing import (
|
||||
spawn) # pyright: ignore[reportPrivateImportUsage]
|
||||
from typing_extensions import Concatenate, ParamSpec
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import override_config
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedExperts, BatchedPrepareAndFinalize, BatchedTritonExperts)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk,
|
||||
get_default_config)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEModularKernel)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
PPLX_PREPARE_COMBOS = [(4, 128, 128), (32, 1024, 512), (64, 1024, 512),
|
||||
(222, 2048, 1024)]
|
||||
|
||||
PPLX_MOE_COMBOS = [
|
||||
(1, 128, 128),
|
||||
(2, 128, 512),
|
||||
(3, 1024, 2048),
|
||||
(32, 128, 1024),
|
||||
(45, 512, 2048),
|
||||
(64, 1024, 1024),
|
||||
(222, 1024, 2048),
|
||||
]
|
||||
|
||||
NUM_EXPERTS = [8, 64]
|
||||
EP_SIZE = [1, 4]
|
||||
TOP_KS = [1, 2, 6]
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.scheduler_config.max_num_seqs = 128
|
||||
vllm_config.scheduler_config.max_model_len = 8192
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
||||
requires_pplx = pytest.mark.skipif(
|
||||
not has_pplx,
|
||||
reason="Requires PPLX kernels",
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ProcessGroupInfo:
|
||||
world_size: int
|
||||
world_local_size: int
|
||||
rank: int
|
||||
node_rank: int
|
||||
local_rank: int
|
||||
device: torch.device
|
||||
|
||||
|
||||
def _worker_parallel_launch(
|
||||
local_rank: int,
|
||||
world_size: int,
|
||||
world_local_size: int,
|
||||
node_rank: int,
|
||||
init_method: str,
|
||||
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> None:
|
||||
rank = node_rank * world_local_size + local_rank
|
||||
torch.cuda.set_device(local_rank)
|
||||
device = torch.device("cuda", local_rank)
|
||||
torch.distributed.init_process_group(
|
||||
backend="cpu:gloo,cuda:nccl",
|
||||
init_method=init_method,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
device_id=device,
|
||||
)
|
||||
barrier = torch.tensor([rank], device=device)
|
||||
torch.distributed.all_reduce(barrier)
|
||||
|
||||
try:
|
||||
worker(
|
||||
ProcessGroupInfo(
|
||||
world_size=world_size,
|
||||
world_local_size=world_local_size,
|
||||
rank=rank,
|
||||
node_rank=node_rank,
|
||||
local_rank=local_rank,
|
||||
device=device,
|
||||
),
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
except Exception as ex:
|
||||
print(ex)
|
||||
traceback.print_exc()
|
||||
raise
|
||||
finally:
|
||||
torch.distributed.destroy_process_group()
|
||||
|
||||
|
||||
def parallel_launch(
|
||||
world_size: int,
|
||||
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> None:
|
||||
assert not kwargs
|
||||
spawn(
|
||||
_worker_parallel_launch,
|
||||
args=(
|
||||
world_size,
|
||||
world_size,
|
||||
0,
|
||||
"tcp://localhost:29500",
|
||||
worker,
|
||||
) + args,
|
||||
nprocs=world_size,
|
||||
join=True,
|
||||
)
|
||||
|
||||
|
||||
def parallel_launch_from_env(
|
||||
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Launches a worker function in parallel across all processes in the current
|
||||
environment. The environment must have the following variables set:
|
||||
- WORLD_SIZE: The total number of processes.
|
||||
- WORLD_LOCAL_SIZE: The number of processes on the current node.
|
||||
- NODE_RANK: The rank of the current node.
|
||||
- MASTER_ADDR: The address of the master process.
|
||||
- MASTER_PORT: The port of the master process.
|
||||
"""
|
||||
assert not kwargs
|
||||
world_size = int(os.environ["WORLD_SIZE"])
|
||||
world_local_size = int(os.environ["WORLD_LOCAL_SIZE"])
|
||||
node_rank = int(os.environ["NODE_RANK"])
|
||||
assert "MASTER_ADDR" in os.environ
|
||||
assert "MASTER_PORT" in os.environ
|
||||
spawn(
|
||||
_worker_parallel_launch,
|
||||
args=(
|
||||
world_size,
|
||||
world_local_size,
|
||||
node_rank,
|
||||
"env://",
|
||||
worker,
|
||||
) + args,
|
||||
nprocs=world_local_size,
|
||||
join=True,
|
||||
)
|
||||
|
||||
|
||||
def torch_prepare(
|
||||
a: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
max_num_tokens: Optional[int] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert topk_ids.dim() == 2
|
||||
assert topk_ids.shape[0] == a.shape[0]
|
||||
|
||||
num_tokens, hidden_dim = a.shape
|
||||
topk = topk_ids.shape[1]
|
||||
|
||||
tokens_per_expert = torch.bincount(topk_ids.view(-1),
|
||||
minlength=num_experts)
|
||||
|
||||
assert tokens_per_expert.numel() == num_experts
|
||||
|
||||
if max_num_tokens is None:
|
||||
max_num_tokens = int(tokens_per_expert.max().item())
|
||||
|
||||
b_a = torch.zeros((num_experts, max_num_tokens, hidden_dim),
|
||||
dtype=a.dtype,
|
||||
device=a.device)
|
||||
|
||||
token_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device)
|
||||
|
||||
for token in range(num_tokens):
|
||||
for j in range(topk):
|
||||
expert_id = topk_ids[token, j]
|
||||
idx = token_counts[expert_id]
|
||||
b_a[expert_id, idx:idx + 1, :] = a[token, :]
|
||||
token_counts[expert_id] = token_counts[expert_id] + 1
|
||||
|
||||
return b_a, tokens_per_expert
|
||||
|
||||
|
||||
def torch_finalize(b_out: torch.Tensor, topk_weight: torch.Tensor,
|
||||
topk_ids: torch.Tensor) -> torch.Tensor:
|
||||
num_tokens = topk_ids.shape[0]
|
||||
num_experts = b_out.shape[0]
|
||||
K = b_out.shape[-1]
|
||||
out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device)
|
||||
expert_counts = torch.zeros(num_experts,
|
||||
dtype=torch.int,
|
||||
device=b_out.device)
|
||||
for token in range(num_tokens):
|
||||
expert_ids = topk_ids[token]
|
||||
for i in range(expert_ids.numel()):
|
||||
expert_id = expert_ids[i]
|
||||
idx = expert_counts[expert_id]
|
||||
out[token, :] = out[token, :] + b_out[expert_id, idx:idx +
|
||||
1, :] * topk_weight[token, i]
|
||||
expert_counts[expert_id] = expert_counts[expert_id] + 1
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def torch_batched_moe(
|
||||
a: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weight: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
num_experts = w1.shape[0]
|
||||
b_a, tokens_per_expert = torch_prepare(a, topk_ids, num_experts)
|
||||
assert b_a.dim() == 3
|
||||
num_tokens, topk = topk_ids.shape
|
||||
_, max_num_tokens, K = b_a.shape
|
||||
assert num_experts == b_a.shape[0] and w2.shape[1] == K
|
||||
out = torch.zeros((num_experts, max_num_tokens, K),
|
||||
dtype=b_a.dtype,
|
||||
device=b_a.device)
|
||||
tmp = torch.empty((max_num_tokens, w1.shape[1] // 2),
|
||||
dtype=b_a.dtype,
|
||||
device=b_a.device)
|
||||
for expert in range(num_experts):
|
||||
num = tokens_per_expert[expert]
|
||||
if num > 0:
|
||||
torch.ops._C.silu_and_mul(
|
||||
tmp[:num], b_a[expert, :num, :] @ w1[expert].transpose(0, 1))
|
||||
out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1)
|
||||
|
||||
return torch_finalize(out, topk_weight, topk_ids)
|
||||
|
||||
|
||||
def batched_moe(
|
||||
a: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weight: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
num_experts = w1.shape[0]
|
||||
|
||||
fused_experts = FusedMoEModularKernel(
|
||||
BatchedPrepareAndFinalize(a.shape[0], world_size=1, dp_size=1, rank=0),
|
||||
BatchedExperts(max_num_tokens=a.shape[0], dp_size=1, world_size=1))
|
||||
|
||||
return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts)
|
||||
|
||||
|
||||
# Note: same as torch_moe but with fused_topk factored out.
|
||||
def torch_moe2(
|
||||
a: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weight: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
M, K = a.shape
|
||||
topk = topk_ids.shape[1]
|
||||
a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
|
||||
out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
||||
num_experts = w1.shape[0]
|
||||
for i in range(num_experts):
|
||||
mask = (topk_ids == i).view(-1)
|
||||
if mask.sum():
|
||||
out[mask] = SiluAndMul()(
|
||||
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
|
||||
|
||||
return (out.view(M, -1, w2.shape[1]) *
|
||||
topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [1, 33, 64, 222])
|
||||
@pytest.mark.parametrize("n", [128, 1024, 2048])
|
||||
@pytest.mark.parametrize("k", [128, 512, 1024])
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
def test_fused_moe_batched_experts(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||
baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
|
||||
torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids)
|
||||
batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids)
|
||||
|
||||
torch.testing.assert_close(baseline_output,
|
||||
torch_output,
|
||||
atol=2e-2,
|
||||
rtol=0)
|
||||
torch.testing.assert_close(baseline_output,
|
||||
batched_output,
|
||||
atol=2e-2,
|
||||
rtol=0)
|
||||
|
||||
|
||||
def rank_chunk(num: int, r: int, w: int) -> int:
    rem = num % w
    return (num // w) + (1 if r < rem else 0)


def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor:
    chunk = rank_chunk(t.shape[0], r, w)
    return t[(r * chunk):(r + 1) * chunk]
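A quick worked example of the chunking helpers above (illustrative only): rank_chunk gives the extra remainder rows to the lowest ranks, and chunk_by_rank slices a tensor into per-rank pieces based on that size.

# Illustrative arithmetic for rank_chunk / chunk_by_rank defined above.
import torch

# 10 rows over 3 ranks: ranks 0, 1, 2 get 4, 3, 3 rows respectively.
assert [rank_chunk(10, r, 3) for r in range(3)] == [4, 3, 3]

# With an evenly divisible size, chunk_by_rank yields disjoint, equal slices.
t = torch.arange(8).reshape(8, 1)
parts = [chunk_by_rank(t, r, 2) for r in range(2)]
assert parts[0].shape[0] == 4 and parts[1].shape[0] == 4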
def pplx_prepare_finalize(pgi: ProcessGroupInfo, dp_size: int, a: torch.Tensor,
|
||||
topk_weight: torch.Tensor, topk_ids: torch.Tensor,
|
||||
num_experts: int) -> torch.Tensor:
|
||||
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
|
||||
PplxPrepareAndFinalize)
|
||||
|
||||
assert torch.cuda.current_device() == pgi.local_rank
|
||||
|
||||
topk = topk_ids.shape[1]
|
||||
num_tokens, hidden_dim = a.shape
|
||||
block_size = 128
|
||||
device = pgi.device
|
||||
rank = pgi.rank
|
||||
world_size = pgi.world_size
|
||||
max_num_tokens = rank_chunk(num_tokens, 0, world_size)
|
||||
|
||||
ata = AllToAll.internode(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_experts=num_experts,
|
||||
experts_per_token=topk,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
dp_size=dp_size,
|
||||
hidden_dim=hidden_dim,
|
||||
hidden_dim_bytes=hidden_dim * a.dtype.itemsize,
|
||||
hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else
|
||||
((hidden_dim + block_size - 1) // block_size *
|
||||
torch.float32.itemsize)),
|
||||
)
|
||||
|
||||
topk_ids = topk_ids.to(dtype=torch.uint32)
|
||||
|
||||
prepare_finalize = PplxPrepareAndFinalize(
|
||||
ata,
|
||||
max_num_tokens,
|
||||
world_size,
|
||||
rank,
|
||||
dp_size,
|
||||
a.dtype,
|
||||
)
|
||||
|
||||
a_chunk = chunk_by_rank(a, rank, world_size).to(device)
|
||||
chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
|
||||
chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)
|
||||
|
||||
b_a, b_a_scale, expert_num_tokens = prepare_finalize.prepare(
|
||||
a_chunk,
|
||||
None,
|
||||
None,
|
||||
chunk_topk_weight,
|
||||
chunk_topk_ids,
|
||||
num_experts,
|
||||
None,
|
||||
False,
|
||||
)
|
||||
|
||||
b_a = b_a * 1.5
|
||||
|
||||
out = torch.full(
|
||||
(max_num_tokens, hidden_dim),
|
||||
torch.nan,
|
||||
dtype=a.dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
prepare_finalize.finalize(
|
||||
out,
|
||||
b_a,
|
||||
chunk_topk_weight,
|
||||
chunk_topk_ids,
|
||||
False,
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
ata.destroy()
|
||||
|
||||
num_tokens = a_chunk.shape[0]
|
||||
|
||||
return out[:num_tokens]
|
||||
|
||||
|
||||
def _pplx_prepare_finalize(
|
||||
pgi: ProcessGroupInfo,
|
||||
dp_size: int,
|
||||
a: torch.Tensor,
|
||||
score: torch.Tensor,
|
||||
topk: torch.Tensor,
|
||||
num_experts: int,
|
||||
):
|
||||
uid = nvshmem_get_unique_id(
|
||||
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
|
||||
torch.distributed.broadcast(uid, src=0)
|
||||
nvshmem_init(uid, pgi.rank, pgi.world_size)
|
||||
device = pgi.device
|
||||
|
||||
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||
k = a.shape[1]
|
||||
|
||||
a_rep = torch.repeat_interleave(a, topk, dim=0).to(device)
|
||||
|
||||
torch_output = (a_rep.view(-1, topk, k) * 1.5 *
|
||||
topk_weight.view(-1, topk, 1).to(device)).sum(dim=1).to(
|
||||
a.dtype)
|
||||
|
||||
pplx_output = pplx_prepare_finalize(pgi, dp_size, a, topk_weight, topk_ids,
|
||||
num_experts)
|
||||
|
||||
torch_output = chunk_by_rank(torch_output, pgi.rank,
|
||||
pgi.world_size).to(pplx_output.device)
|
||||
|
||||
torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0)
|
||||
|
||||
nvshmem_finalize()
|
||||
|
||||
|
||||
# TODO (bnell): this test point does not work for odd M due to how the test is
|
||||
# written, not due to limitations of the pplx kernels. The pplx_moe
|
||||
# test below is able to deal with odd M.
|
||||
@pytest.mark.parametrize("mnk", PPLX_PREPARE_COMBOS)
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
|
||||
@requires_pplx
|
||||
def test_pplx_prepare_finalize(
|
||||
mnk: tuple[int, int, int],
|
||||
e: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
world_dp_size: tuple[int, int],
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
m, n, k = mnk
|
||||
world_size, dp_size = world_dp_size
|
||||
device = "cuda"
|
||||
a = torch.randn((m, k), device=device, dtype=dtype) / 10
|
||||
score = torch.randn((m, e), device=device, dtype=dtype)
|
||||
|
||||
parallel_launch(world_size, _pplx_prepare_finalize, dp_size, a, score,
|
||||
topk, e)
|
||||
|
||||
|
||||
def pplx_moe(
|
||||
rank: int,
|
||||
world_size: int,
|
||||
dp_size: int,
|
||||
a: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weight: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
use_compile: bool = True,
|
||||
use_cudagraphs: bool = True,
|
||||
) -> torch.Tensor:
|
||||
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
|
||||
PplxPrepareAndFinalize)
|
||||
|
||||
device = torch.device("cuda", rank)
|
||||
hidden_dim = a.shape[1]
|
||||
num_experts = w1.shape[0]
|
||||
block_size = 128
|
||||
topk = topk_ids.shape[1]
|
||||
max_num_tokens = rank_chunk(a.shape[0], 0, world_size)
|
||||
|
||||
ata = AllToAll.internode(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_experts=num_experts,
|
||||
experts_per_token=topk,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
dp_size=dp_size,
|
||||
hidden_dim=hidden_dim,
|
||||
hidden_dim_bytes=hidden_dim * a.dtype.itemsize,
|
||||
hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else
|
||||
((hidden_dim + block_size - 1) // block_size *
|
||||
torch.float32.itemsize)),
|
||||
)
|
||||
|
||||
topk_ids = topk_ids.to(dtype=torch.uint32)
|
||||
|
||||
prepare_finalize = PplxPrepareAndFinalize(
|
||||
ata,
|
||||
max_num_tokens,
|
||||
world_size,
|
||||
rank,
|
||||
dp_size,
|
||||
)
|
||||
|
||||
experts = BatchedTritonExperts(max_num_tokens=a.shape[0],
|
||||
world_size=world_size,
|
||||
dp_size=dp_size)
|
||||
|
||||
fused_experts = FusedMoEModularKernel(
|
||||
prepare_finalize,
|
||||
experts,
|
||||
)
|
||||
|
||||
# Note: workers with the same dp_rank must use the exact same inputs.
|
||||
a_chunk = chunk_by_rank(a, rank, world_size).to(device)
|
||||
chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
|
||||
chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)
|
||||
|
||||
# Chunking weights like this only works for batched format
|
||||
w1_chunk = chunk_by_rank(w1, rank, world_size).to(device)
|
||||
w2_chunk = chunk_by_rank(w2, rank, world_size).to(device)
|
||||
|
||||
if use_compile:
|
||||
_fused_experts = torch.compile(fused_experts,
|
||||
backend='inductor',
|
||||
fullgraph=True)
|
||||
else:
|
||||
_fused_experts = fused_experts
|
||||
|
||||
out = _fused_experts(a_chunk,
|
||||
w1_chunk,
|
||||
w2_chunk,
|
||||
chunk_topk_weight,
|
||||
chunk_topk_ids,
|
||||
global_num_experts=num_experts)
|
||||
|
||||
if use_cudagraphs:
|
||||
out.fill_(0)
|
||||
stream = torch.cuda.Stream()
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph, stream=stream):
|
||||
out = _fused_experts(a_chunk,
|
||||
w1_chunk,
|
||||
w2_chunk,
|
||||
chunk_topk_weight,
|
||||
chunk_topk_ids,
|
||||
global_num_experts=num_experts)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
graph.replay()
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
ata.destroy()
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
|
||||
assert torch.cuda.current_device() == pgi.local_rank
|
||||
|
||||
num_experts = w1.shape[0]
|
||||
device = pgi.device
|
||||
rank = pgi.rank
|
||||
world_size = pgi.world_size
|
||||
max_num_tokens = rank_chunk(a.shape[0], 0, world_size)
|
||||
|
||||
prepare_finalize = BatchedPrepareAndFinalize(
|
||||
max_num_tokens=max_num_tokens,
|
||||
world_size=world_size,
|
||||
dp_size=dp_size,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
experts = BatchedExperts(max_num_tokens=a.shape[0],
|
||||
world_size=1,
|
||||
dp_size=1)
|
||||
|
||||
fused_experts = FusedMoEModularKernel(
|
||||
prepare_finalize,
|
||||
experts,
|
||||
)
|
||||
|
||||
# Note: workers with the same dp_rank must use the exact same inputs.
|
||||
a_chunk = chunk_by_rank(a, rank, world_size).to(device)
|
||||
chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
|
||||
chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)
|
||||
|
||||
out = fused_experts(
|
||||
a_chunk,
|
||||
# Chunking weights like this only works for batched format
|
||||
chunk_by_rank(w1, rank, world_size).to(device),
|
||||
chunk_by_rank(w2, rank, world_size).to(device),
|
||||
chunk_topk_weight,
|
||||
chunk_topk_ids,
|
||||
global_num_experts=num_experts)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def _pplx_moe(
|
||||
pgi: ProcessGroupInfo,
|
||||
dp_size: int,
|
||||
a: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
score: torch.Tensor,
|
||||
topk: int,
|
||||
):
|
||||
uid = nvshmem_get_unique_id(
|
||||
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
|
||||
torch.distributed.broadcast(uid, src=0)
|
||||
nvshmem_init(uid, pgi.rank, pgi.world_size)
|
||||
|
||||
m, k = a.shape
|
||||
e, _, n = w2.shape
|
||||
|
||||
moe_config = get_default_config(m, e, n, k, topk, a.dtype, False)
|
||||
|
||||
with set_current_vllm_config(vllm_config), override_config(moe_config):
|
||||
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||
torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
|
||||
pplx_output = pplx_moe(pgi.rank, pgi.world_size, dp_size, a, w1, w2,
|
||||
topk_weight, topk_ids)
|
||||
# TODO (bnell): fix + re-enable
|
||||
#batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight,
|
||||
# topk_ids)
|
||||
|
||||
torch_output = chunk_by_rank(torch_output, pgi.rank,
|
||||
pgi.world_size).to(pplx_output.device)
|
||||
|
||||
torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0)
|
||||
#torch.testing.assert_close(batched_output, torch_output, atol=2e-2, rtol=0)
|
||||
|
||||
nvshmem_finalize()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mnk", PPLX_MOE_COMBOS)
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
|
||||
@requires_pplx
|
||||
def test_pplx_moe(
|
||||
mnk: tuple[int, int, int],
|
||||
e: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
world_dp_size: tuple[int, int],
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
m, n, k = mnk
|
||||
world_size, dp_size = world_dp_size
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
|
||||
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk)
|
||||
@ -7,6 +7,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.platforms import current_platform
|
||||
@ -15,6 +16,10 @@ if current_platform.get_device_capability() < (9, 0):
|
||||
pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
|
||||
allow_module_level=True)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.scheduler_config.max_num_seqs = 128
|
||||
vllm_config.scheduler_config.max_model_len = 8192
|
||||
|
||||
|
||||
def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
|
||||
"""Matrix multiplication function that supports per-token input
|
||||
@ -137,6 +142,7 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
|
||||
w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk)
|
||||
out = fused_moe(
|
||||
a,
|
||||
|
||||
@ -11,7 +11,7 @@ from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||
deep_gemm_moe_fp8)
|
||||
_valid_deep_gemm_shape, deep_gemm_moe_fp8)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
||||
moe_align_block_size)
|
||||
@ -30,6 +30,10 @@ if current_platform.get_device_capability() < (9, 0):
|
||||
pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
|
||||
allow_module_level=True)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.scheduler_config.max_num_seqs = 128
|
||||
vllm_config.scheduler_config.max_model_len = 8192
|
||||
|
||||
# Test configurations
|
||||
DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
|
||||
NUM_TOKENS = [7, 83, 2048]
|
||||
@ -210,7 +214,6 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
|
||||
# Set the context to avoid lots of warning spam.
|
||||
vllm_config = VllmConfig()
|
||||
with set_current_vllm_config(vllm_config):
|
||||
out = fused_moe(
|
||||
a,
|
||||
@ -258,6 +261,7 @@ def per_block_cast_to_fp8(
|
||||
@pytest.mark.parametrize(
|
||||
"M,N,K,block_size,out_dtype,seed",
|
||||
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
|
||||
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
|
||||
# only aligned sizes
|
||||
@ -381,15 +385,11 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
|
||||
block_size = [block_m, block_m]
|
||||
dtype = torch.bfloat16
|
||||
|
||||
# only aligned sizes
|
||||
if (N % block_m != 0 or K % block_m != 0 or topk > E):
|
||||
pytest.skip(
|
||||
f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}")
|
||||
if topk > E:
|
||||
pytest.skip(f"Skipping test: topk={topk} > E={E}")
|
||||
|
||||
if N <= 512:
|
||||
pytest.skip("Skipping N <= 512 until performance issues solved.")
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
if not _valid_deep_gemm_shape(M, N, K):
|
||||
pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}")
|
||||
|
||||
torch.manual_seed(seed)
|
||||
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
||||
|
||||
@ -18,6 +18,10 @@ if current_platform.get_device_capability() < (7, 0):
|
||||
pytest.skip("INT8 Triton requires CUDA 7.0 or higher",
|
||||
allow_module_level=True)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.scheduler_config.max_num_seqs = 128
|
||||
vllm_config.scheduler_config.max_model_len = 8192
|
||||
|
||||
|
||||
# For test
|
||||
def native_per_token_group_quant_int8(x,
|
||||
@ -174,7 +178,6 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
|
||||
# Set the context to avoid lots of warning spam.
|
||||
vllm_config = VllmConfig()
|
||||
with set_current_vllm_config(vllm_config):
|
||||
out = fused_moe(
|
||||
a,
|
||||
|
||||
@@ -23,6 +23,7 @@ If you only need to use the distributed environment without model/pipeline
"""
import contextlib
import gc
import importlib.util
import pickle
import weakref
from collections import namedtuple
@@ -42,7 +43,7 @@ from vllm.distributed.device_communicators.base_device_communicator import (
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname,
                        supports_custom_op)
                        run_once, supports_custom_op)


@dataclass
@@ -936,9 +937,49 @@ def init_distributed_environment(
            "world group already initialized with a different world size")


PPLX_DID_INIT: bool = False


@run_once
def pplx_init(rank, world_size):
    has_pplx = importlib.util.find_spec("pplx_kernels") is not None

    if has_pplx and world_size > 1:
        from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
                                          nvshmem_get_unique_id, nvshmem_init)
        try:
            global PPLX_DID_INIT
            logger.debug(
                "Initialize NVSHMEM for PPLX kernels: rank=%d, "
                "world size=%d", rank, world_size)
            uid = nvshmem_get_unique_id(
            ) if rank == 0 else nvshmem_alloc_empty_unique_id()
            uid_gpu = uid.cuda()
            get_world_group().broadcast(uid_gpu, src=0)
            uid = uid_gpu.to(device='cpu')
            logger.debug("PPLX NVSHMEM UID = %s", uid)
            nvshmem_init(uid, rank, world_size)
            PPLX_DID_INIT = True
        except Exception as ex:
            logger.error("Failed to initialize NVSHMEM for PPLX: %s", ex)


@run_once
def pplx_finalize():
    global PPLX_DID_INIT
    if PPLX_DID_INIT:
        from pplx_kernels.nvshmem import nvshmem_finalize
        logger.debug("PPLX NVSHMEM finalize")
        from vllm.model_executor.layers.fused_moe.layer import (
            _all_to_all_cache)
        _all_to_all_cache.destroy()
        nvshmem_finalize()
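Both helpers above rely on the run_once decorator so that NVSHMEM setup and teardown happen at most once per process, no matter how many times initialization or destruction paths are entered. A toy sketch of that behaviour, assuming vllm.utils.run_once executes the wrapped function only on its first call (that semantics is an assumption inferred from its use here, not spelled out in this diff):

# Illustrative sketch of the assumed @run_once semantics, using a toy function
# instead of the real NVSHMEM setup.
from vllm.utils import run_once

calls = []

@run_once
def setup():
    calls.append("initialized")

setup()
setup()  # second call is a no-op: the body ran only once
assert calls == ["initialized"]

This is what lets initialize_model_parallel below call pplx_init on every expert-parallel init path while NVSHMEM is set up once, and lets destroy_model_parallel call pplx_finalize unconditionally.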


def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    enable_expert_parallel: bool = False,
    backend: Optional[str] = None,
) -> None:
    """
@@ -1041,10 +1082,14 @@ def initialize_model_parallel(
        _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group,
        _EP.rank_in_group)

    if enable_expert_parallel:
        pplx_init(rank, world_size)


def ensure_model_parallel_initialized(
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
    enable_expert_parallel: bool = False,
    backend: Optional[str] = None,
) -> None:
    """Helper to initialize model parallel groups if they are not initialized,
@@ -1055,7 +1100,8 @@ def ensure_model_parallel_initialized(
                             get_world_group().device_group)
    if not model_parallel_is_initialized():
        initialize_model_parallel(tensor_model_parallel_size,
                                  pipeline_model_parallel_size, backend)
                                  pipeline_model_parallel_size,
                                  enable_expert_parallel, backend)
        return

    assert (
@@ -1133,6 +1179,9 @@ def get_tensor_model_parallel_rank():
def destroy_model_parallel():
    """Set the groups to none and destroy them."""
    global _TP

    pplx_finalize()

    if _TP:
        _TP.destroy()
        _TP = None
@@ -23,7 +23,7 @@ from torch.distributed.rendezvous import rendezvous

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import get_tcp_uri
from vllm.utils import get_tcp_uri, is_torch_equal_or_newer

logger = init_logger(__name__)

@@ -362,12 +362,11 @@ def stateless_destroy_torch_distributed_process_group(
    Destroy ProcessGroup returned by
    stateless_init_torch_distributed_process_group().
    """
    if is_torch_equal_or_newer("2.7"):
        pg.shutdown()
    else:
        # Lazy import for non-CUDA backends.
        try:
            # pytorch <= 2.6
            from torch.distributed.distributed_c10d import _shutdown_backend
            _shutdown_backend(pg)
        except ImportError:
            # pytorch >= 2.7
            pg.shutdown()

    _unregister_process_group(pg.group_name)
@@ -27,6 +27,7 @@ batchsize_forward_time: defaultdict = defaultdict(list)

@dataclass
class DPMetadata:
    max_tokens_across_dp_cpu: torch.Tensor
    cu_tokens_across_dp_cpu: torch.Tensor


@@ -90,8 +91,10 @@ def set_forward_context(attn_metadata: Any,
                                        dtype=torch.int32)
            from vllm.distributed.parallel_state import get_dp_group
            dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
            max_tokens_across_dp_cpu = torch.max(num_tokens_tensor)
            cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
            dp_metadata = DPMetadata(cu_tokens_across_dp_cpu)
            dp_metadata = DPMetadata(max_tokens_across_dp_cpu,
                                     cu_tokens_across_dp_cpu)

    global _forward_context
    prev_context = _forward_context
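Referring back to the DPMetadata change above: the dataclass now carries both the maximum and the cumulative token counts across data-parallel ranks. A small standalone sketch of how the two fields relate, using a made-up per-rank token count tensor:

# Illustrative only: relation between the two DPMetadata fields for a
# hypothetical set of per-DP-rank token counts.
import torch

num_tokens_tensor = torch.tensor([5, 3, 8, 2], dtype=torch.int32)  # one entry per DP rank
max_tokens_across_dp_cpu = torch.max(num_tokens_tensor)            # 8
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)   # [5, 8, 16, 18]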
@@ -38,8 +38,8 @@ if HAS_TRITON:
    from vllm.model_executor.layers.fused_moe.cutlass_moe import (
        cutlass_moe_fp4, cutlass_moe_fp8)
    from vllm.model_executor.layers.fused_moe.fused_moe import (
        fused_experts, fused_moe, fused_topk, get_config_file_name,
        grouped_topk)
        TritonExperts, fused_experts, fused_moe, fused_topk,
        get_config_file_name, grouped_topk)

    __all__ += [
        "fused_moe",
@@ -49,4 +49,5 @@ if HAS_TRITON:
        "grouped_topk",
        "cutlass_moe_fp8",
        "cutlass_moe_fp4",
        "TritonExperts",
    ]
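Since TritonExperts is now re-exported from the package root, callers can import it directly; a minimal sketch (only valid on a Triton-enabled build, since the export sits under the HAS_TRITON branch, and the class constructor arguments are not shown in this diff):

# Minimal import sketch; requires HAS_TRITON to be true at import time.
from vllm.model_executor.layers.fused_moe import TritonExperts, fused_topk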
@ -5,10 +5,176 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP)
|
||||
from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
|
||||
class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ab_strides1: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides2: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
):
|
||||
super().__init__()
|
||||
self.ab_strides1 = ab_strides1
|
||||
self.c_strides1 = c_strides1
|
||||
self.ab_strides2 = ab_strides2
|
||||
self.c_strides2 = c_strides2
|
||||
self.out_dtype = out_dtype
|
||||
|
||||
    def workspace_shapes(
        self,
        a: torch.Tensor,
        M: int,
        N: int,
        K: int,
        topk: int,
        num_experts: int,
    ) -> tuple[int, int, torch.dtype]:
        # Note that K, N are transposed
        N, K = K, N
        workspace1 = M * topk * max(2 * N, K)
        workspace2 = M * topk * N
        return (workspace1, workspace2, self.out_dtype)
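A quick worked example of the workspace sizing above, with illustrative numbers only: each routed token (M * topk of them) needs max(2N, K) elements of scratch for the first gemm plus activation, and N elements for the intermediate of the second gemm.

# Illustrative sizing only, mirroring workspace_shapes() above.
M, topk = 4, 2    # tokens and experts per token
N, K = 512, 1024  # after the N/K swap inside workspace_shapes

workspace1 = M * topk * max(2 * N, K)  # 4 * 2 * 1024 = 8192 elements
workspace2 = M * topk * N              # 4 * 2 * 512  = 4096 elements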
def apply(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_num_tokens: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
a1q = hidden_states
|
||||
|
||||
assert w1_scale is not None
|
||||
assert w2_scale is not None
|
||||
assert w1.dtype == torch.float8_e4m3fn
|
||||
assert w2.dtype == torch.float8_e4m3fn
|
||||
assert a1q.shape[1] == w1.shape[1], "Hidden size mismatch w1"
|
||||
assert w1.shape[2] == w2.shape[1] * 2, "Hidden size mismatch w2"
|
||||
assert w1.shape[0] == w2.shape[0], "Expert number mismatch"
|
||||
assert a1q_scale is None or a1q_scale.dim(
|
||||
) == 0 or a1q_scale.shape[0] == 1 or a1q_scale.shape[0] == a1q.shape[
|
||||
0], "Input scale shape mismatch"
|
||||
assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[
|
||||
1] == w1.shape[2], "W1 scale shape mismatch"
|
||||
assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[
|
||||
1] == w2.shape[2], "W2 scale shape mismatch"
|
||||
assert w1.shape[0] == w2.shape[0], "Weights expert number mismatch"
|
||||
assert w1.shape[0] == w1_scale.shape[
|
||||
0], "w1 scales expert number mismatch"
|
||||
assert w1.shape[0] == w2_scale.shape[
|
||||
0], "w2 scales expert number mismatch"
|
||||
assert a2_scale is None or a1q_scale is None or a2_scale.shape == a1q_scale.shape, "Intermediate scale shape mismatch" # noqa: E501
|
||||
assert self.ab_strides1.shape[0] == w1.shape[
|
||||
0], "AB Strides 1 expert number mismatch"
|
||||
assert self.c_strides1.shape[0] == w1.shape[
|
||||
0], "C Strides 1 expert number mismatch"
|
||||
assert self.ab_strides2.shape[0] == w2.shape[
|
||||
0], "AB Strides 2 expert number mismatch"
|
||||
assert self.c_strides2.shape[0] == w2.shape[
|
||||
0], "C Strides 2 expert number mismatch"
|
||||
assert self.out_dtype in [torch.half,
|
||||
torch.bfloat16], "Invalid output dtype"
|
||||
|
||||
M = a1q.shape[0]
|
||||
_, N, K = w2.shape # because w1 + w2 are transposed
|
||||
device = a1q.device
|
||||
|
||||
assert w1.shape[1] == K
|
||||
assert global_num_experts != -1
|
||||
assert a1q_scale is not None
|
||||
|
||||
if expert_map is not None:
|
||||
"Translate info from expert_map to topk_ids"
|
||||
local_topk_ids = torch.where(expert_map[topk_ids] != -1,
|
||||
expert_map[topk_ids], -1)
|
||||
else:
|
||||
local_topk_ids = topk_ids
|
||||
|
||||
topk = local_topk_ids.shape[1]
|
||||
|
||||
per_act_token = a1q_scale.numel() != 1 if a1q_scale is not None else (
|
||||
a2_scale.numel() != 1 if a2_scale is not None else False)
|
||||
|
||||
expert_offsets = torch.empty((global_num_experts + 1),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
problem_sizes1 = torch.empty((global_num_experts, 3),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
problem_sizes2 = torch.empty((global_num_experts, 3),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
# With expert_map each Rank processes only a subset of experts. As
|
||||
# a result not all of a_map and c2 tensors are filled. We fill it
|
||||
# zeros for correctness.
|
||||
if expert_map is not None:
|
||||
a_map = torch.zeros((local_topk_ids.numel()),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
else:
|
||||
a_map = torch.empty((local_topk_ids.numel()),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
c_map = torch.empty((local_topk_ids.numel()),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets,
|
||||
problem_sizes1, problem_sizes2, a_map,
|
||||
c_map, global_num_experts, N, K)
|
||||
|
||||
a1q = _fp8_perm(a1q, a_map)
|
||||
a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale
|
||||
|
||||
c1 = _resize_cache(workspace13, (M * topk, N * 2))
|
||||
c2 = _resize_cache(workspace2, (M * topk, N))
|
||||
c3 = _resize_cache(workspace13, (M * topk, K))
|
||||
|
||||
ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale,
|
||||
expert_offsets[:-1], problem_sizes1,
|
||||
self.ab_strides1, self.ab_strides1, self.c_strides1)
|
||||
|
||||
self.activation(activation, c2, c1)
|
||||
|
||||
a2q, a2q_scale = ops.scaled_fp8_quant(
|
||||
c2, a2_scale, use_per_token_if_dynamic=per_act_token)
|
||||
|
||||
if expert_map is not None:
|
||||
c3.fill_(0)
|
||||
|
||||
ops.cutlass_moe_mm(c3, a2q, w2, a2q_scale, w2_scale,
|
||||
expert_offsets[:-1], problem_sizes2,
|
||||
self.ab_strides2, self.ab_strides2, self.c_strides2)
|
||||
|
||||
c3 = c3[c_map]
|
||||
|
||||
return c3
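# --- Illustrative sketch (hypothetical, not part of this commit) ---
# How the expert_map translation used in apply() above behaves: with 4 global
# experts where experts 2 and 3 are owned by this rank (local ids 0 and 1),
# any topk entry that routes to a non-local expert is replaced by -1.
def _example_expert_map_translation():
    import torch
    expert_map = torch.tensor([-1, -1, 0, 1])  # global expert id -> local id or -1
    topk_ids = torch.tensor([[0, 2], [3, 1]])
    local_topk_ids = torch.where(expert_map[topk_ids] != -1,
                                 expert_map[topk_ids], -1)
    assert local_topk_ids.tolist() == [[-1, 0], [1, -1]]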
|
||||
|
||||
|
||||
#TODO make the grouped gemm kernel consistent with scaled gemm kernel
|
||||
def cutlass_moe_fp8(
|
||||
a: torch.Tensor,
|
||||
@ -17,7 +183,7 @@ def cutlass_moe_fp8(
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids_: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
ab_strides1: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
@ -59,7 +225,7 @@ def cutlass_moe_fp8(
|
||||
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
|
||||
quantize the intermediate result between the gemms.
|
||||
Shape: scalar or [M]
|
||||
- out_dtype (torch.Tensor): The output tensor type.
|
||||
- out_dtype (torch.dtype): The output tensor type.
|
||||
- expert_map (Optional[torch.Tensor]): In the case of Expert parallel,
|
||||
every Rank is responsible for a subset of experts. expert_map is a
|
||||
mapping from global expert-id to local expert-id. When expert_map[i]
|
||||
@ -71,115 +237,36 @@ def cutlass_moe_fp8(
|
||||
Returns:
|
||||
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
|
||||
"""
|
||||
|
||||
assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch"
|
||||
assert w1_q.dtype == torch.float8_e4m3fn
|
||||
assert w2_q.dtype == torch.float8_e4m3fn
|
||||
assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1"
|
||||
assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2"
|
||||
assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
|
||||
assert a1_scale is None or a1_scale.dim(
|
||||
) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[0] == a.shape[
|
||||
0], "Input scale shape mismatch"
|
||||
assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[
|
||||
1] == w1_q.shape[2], "W1 scale shape mismatch"
|
||||
assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[
|
||||
1] == w2_q.shape[2], "W2 scale shape mismatch"
|
||||
assert w1_q.shape[0] == w2_q.shape[0], "Weights expert number mismatch"
|
||||
assert w1_q.shape[0] == w1_scale.shape[
|
||||
0], "w1 scales expert number mismatch"
|
||||
assert w1_q.shape[0] == w2_scale.shape[
|
||||
0], "w2 scales expert number mismatch"
|
||||
assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501
|
||||
assert ab_strides1.shape[0] == w1_q.shape[
|
||||
0], "AB Strides 1 expert number mismatch"
|
||||
assert c_strides1.shape[0] == w1_q.shape[
|
||||
0], "C Strides 1 expert number mismatch"
|
||||
assert ab_strides2.shape[0] == w2_q.shape[
|
||||
0], "AB Strides 2 expert number mismatch"
|
||||
assert c_strides2.shape[0] == w2_q.shape[
|
||||
0], "C Strides 2 expert number mismatch"
|
||||
assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
|
||||
|
||||
num_experts = w1_q.size(0)
|
||||
m = a.size(0)
|
||||
k = w1_q.size(1)
|
||||
n = w2_q.size(1)
|
||||
|
||||
local_topk_ids = topk_ids_
|
||||
if expert_map is not None:
|
||||
"Translate info from expert_map to topk_ids"
|
||||
local_topk_ids = torch.where(expert_map[topk_ids_] != -1,
|
||||
expert_map[topk_ids_], -1)
|
||||
|
||||
topk = local_topk_ids.size(1)
|
||||
|
||||
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
|
||||
a2_scale.numel() != 1 if a2_scale is not None else False)
|
||||
if apply_router_weight_on_input:
|
||||
assert topk == 1, \
|
||||
"apply_router_weight_on_input is only implemented for topk=1"
|
||||
# TODO: this only works for topK=1, will need to update for topK>1
|
||||
a = a * topk_weights.to(out_dtype)
|
||||
|
||||
a_q, a1_scale = ops.scaled_fp8_quant(
|
||||
a, a1_scale, use_per_token_if_dynamic=per_act_token)
|
||||
device = a_q.device
|
||||
fn = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(
|
||||
per_channel_quant=per_act_token,
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
),
|
||||
CutlassExpertsFp8(
|
||||
ab_strides1,
|
||||
c_strides1,
|
||||
ab_strides2,
|
||||
c_strides2,
|
||||
out_dtype,
|
||||
),
|
||||
)
|
||||
|
||||
expert_offsets = torch.empty((num_experts + 1),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
problem_sizes1 = torch.empty((num_experts, 3),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
problem_sizes2 = torch.empty((num_experts, 3),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
a_map_initializer = torch.empty
|
||||
c2_initializer = torch.empty
|
||||
if expert_map is not None:
|
||||
# With expert_map each rank processes only a subset of experts. As
# a result, not all of the a_map and c2 tensors are filled. We fill
# them with zeros for correctness.
|
||||
a_map_initializer = torch.zeros
|
||||
c2_initializer = torch.zeros
|
||||
|
||||
a_map = a_map_initializer((local_topk_ids.numel()),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
c_map = torch.empty((local_topk_ids.numel()),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets, problem_sizes1,
|
||||
problem_sizes2, a_map, c_map, num_experts, n,
|
||||
k)
|
||||
|
||||
rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype)
|
||||
rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale
|
||||
|
||||
c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
|
||||
c2 = c2_initializer((m * topk, k), device=device, dtype=out_dtype)
|
||||
|
||||
ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a1_scales, w1_scale,
|
||||
expert_offsets[:-1], problem_sizes1, ab_strides1,
|
||||
ab_strides1, c_strides1)
|
||||
|
||||
intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
|
||||
torch.ops._C.silu_and_mul(intermediate, c1)
|
||||
|
||||
intemediate_q, a2_scale = ops.scaled_fp8_quant(
|
||||
intermediate, a2_scale, use_per_token_if_dynamic=per_act_token)
|
||||
|
||||
ops.cutlass_moe_mm(c2, intemediate_q, w2_q, a2_scale, w2_scale,
|
||||
expert_offsets[:-1], problem_sizes2, ab_strides2,
|
||||
ab_strides2, c_strides2)
|
||||
# Gather tokens
|
||||
c2 = c2[c_map].view(m, topk, k)
|
||||
if not apply_router_weight_on_input:
|
||||
c2 = c2 * topk_weights.view(m, topk, 1).to(out_dtype)
|
||||
return c2.sum(dim=1)
|
||||
return fn(
|
||||
a,
|
||||
w1_q,
|
||||
w2_q,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
|
||||
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
|
||||
|
||||
@ -1,16 +1,17 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import functools
|
||||
import importlib.util
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
||||
moe_align_block_size)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
|
||||
_fp8_quantize,
|
||||
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
|
||||
_moe_permute)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
|
||||
_resize_cache)
|
||||
from vllm.utils import round_up
|
||||
|
||||
@@ -19,6 +20,19 @@ logger = init_logger(__name__)
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None


@functools.cache
def deep_gemm_block_shape() -> list[int]:
    # Lazy import to avoid CUDA initialization problems.
    import deep_gemm as dg
    block = dg.get_m_alignment_for_contiguous_layout()
    return [block, block]


def _valid_deep_gemm_shape(M: int, N: int, K: int):
    align = deep_gemm_block_shape()[0]
    return align <= M and N % align == 0 and K % align == 0

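# --- Illustrative sketch (hypothetical, not part of this commit) ---
# What _valid_deep_gemm_shape accepts, assuming the alignment is 128; the real
# value comes from dg.get_m_alignment_for_contiguous_layout() at runtime.
def _example_valid_deep_gemm_shape():
    align = 128  # assumed for illustration only

    def ok(M, N, K):
        return align <= M and N % align == 0 and K % align == 0

    assert ok(256, 4096, 7168)      # enough rows, N and K aligned
    assert not ok(64, 4096, 7168)   # too few rows (M < align)
    assert not ok(256, 4100, 7168)  # N not a multiple of align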
def _valid_deep_gemm(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
@ -29,89 +43,112 @@ def _valid_deep_gemm(hidden_states: torch.Tensor,
|
||||
aligned by `dg.get_m_alignment_for_contiguous_layout()`.
|
||||
"""
|
||||
if not has_deep_gemm:
|
||||
logger.debug("DeepGemm disabled: deep_gemm not available.")
|
||||
return False
|
||||
|
||||
# Lazy import to avoid CUDA initialization problems.
|
||||
import deep_gemm as dg
|
||||
|
||||
# Expert maps not supported yet.
|
||||
if expert_map is not None:
|
||||
logger.debug("DeepGemm disabled: expert map NYI.")
|
||||
return False
|
||||
|
||||
align = dg.get_m_alignment_for_contiguous_layout()
|
||||
M = hidden_states.shape[0]
|
||||
_, K, N = w2.shape
|
||||
|
||||
# For now, disable DeepGemm for small N until better permute/unpermute
|
||||
# ops are available.
|
||||
if N <= 512:
|
||||
M = hidden_states.size(0)
|
||||
_, K, N = w2.size()
|
||||
if not _valid_deep_gemm_shape(M, N, K):
|
||||
logger.debug("DeepGemm disabled: unalinged problem size.")
|
||||
return False
|
||||
|
||||
if align > M or N % align != 0 or K % align != 0:
|
||||
if (w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn):
|
||||
logger.debug("DeepGemm disabled: invalid weight dtype(s).")
|
||||
return False
|
||||
|
||||
return (hidden_states.is_contiguous() and w1.is_contiguous()
|
||||
and w2.is_contiguous())
|
||||
if (not hidden_states.is_contiguous() or not w1.is_contiguous()
|
||||
or not w2.is_contiguous()):
|
||||
logger.debug(
|
||||
"DeepGemm disabled: weights or activations not contiguous.")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _moe_permute(
|
||||
curr_hidden_states: torch.Tensor,
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
curr_topk_ids: torch.Tensor,
|
||||
class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.block_shape = deep_gemm_block_shape()
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
topk: int,
|
||||
num_experts: int,
|
||||
) -> tuple[int, int, torch.dtype]:
|
||||
block_m = self.block_shape[0]
|
||||
M_sum = (M * topk) + num_experts * (block_m - 1)
|
||||
M_sum = round_up(M_sum, block_m)
|
||||
workspace1 = M_sum * max(N * 2, K)
|
||||
workspace2 = M_sum * N
|
||||
return (workspace1, workspace2, a.dtype)
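# --- Illustrative sketch (hypothetical, not part of this commit) ---
# Worked example of the workspace sizing above, assuming block_m = 128 (the
# real value comes from deep_gemm_block_shape()): each expert can add up to
# block_m - 1 rows of padding, and the padded row count is rounded up again.
def _example_deep_gemm_workspace_sizing():
    M, topk, num_experts, N, K, block_m = 32, 2, 8, 2048, 4096, 128
    M_sum = (M * topk) + num_experts * (block_m - 1)  # 64 + 1016 = 1080
    M_sum = -(-M_sum // block_m) * block_m            # round up -> 1152
    workspace1 = M_sum * max(N * 2, K)                # 1152 * 4096 elements
    workspace2 = M_sum * N                            # 1152 * 2048 elements
    assert (M_sum, workspace1, workspace2) == (1152, 4718592, 2359296)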
|
||||
|
||||
def apply(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
block_m: int,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
|
||||
Optional[torch.Tensor]]:
|
||||
"""
|
||||
Determine the sorted_token_ids, expert_ids for the given problem size.
|
||||
Permute the hidden states and scales according to `sorted_token_ids`.
|
||||
"""
|
||||
top_k_num = curr_topk_ids.shape[1]
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_num_tokens: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
import deep_gemm as dg
|
||||
|
||||
tokens_in_chunk, _ = curr_hidden_states.shape
|
||||
a1q = hidden_states
|
||||
_, N, K = w1.size()
|
||||
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = (
|
||||
moe_align_block_size(curr_topk_ids,
|
||||
block_m,
|
||||
assert global_num_experts != -1
|
||||
assert w2.size(1) == K
|
||||
|
||||
a1q, a1q_scale, _, expert_ids, inv_perm = _moe_permute(
|
||||
a1q,
|
||||
a1q_scale,
|
||||
topk_ids,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
pad_sorted_ids=True))
|
||||
self.block_shape[0],
|
||||
)
|
||||
|
||||
inv_perm: Optional[torch.Tensor] = None
|
||||
# Note: M_sum is different than the pre-permuted shape of a1q.
|
||||
M_sum = a1q.size(0)
|
||||
workspace1 = _resize_cache(workspace13, (M_sum, N))
|
||||
workspace2 = _resize_cache(workspace2, (M_sum, N // 2))
|
||||
workspace3 = _resize_cache(workspace13, (M_sum, K))
|
||||
|
||||
num_tokens = top_k_num * tokens_in_chunk
|
||||
sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
|
||||
expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)
|
||||
inv_perm = torch.argsort(sorted_token_ids)[:num_tokens]
|
||||
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
(a1q, a1q_scale), (w1, w1_scale), workspace1, expert_ids)
|
||||
|
||||
# Permute according to sorted token ids.
|
||||
curr_hidden_states = _fp8_perm(curr_hidden_states,
|
||||
sorted_token_ids // top_k_num)
|
||||
self.activation(activation, workspace2, workspace1.view(-1, N))
|
||||
|
||||
if a1q_scale is not None:
|
||||
a1q_scale = a1q_scale[sorted_token_ids // top_k_num]
|
||||
a2q_scale: Optional[torch.Tensor] = None
|
||||
|
||||
return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
|
||||
inv_perm)
|
||||
a2q, a2q_scale = _fp8_quantize(workspace2, a2_scale, False,
|
||||
self.block_shape)
|
||||
|
||||
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
(a2q, a2q_scale), (w2, w2_scale), workspace3, expert_ids)
|
||||
|
||||
def _moe_unpermute_and_reduce(
|
||||
out: torch.Tensor,
|
||||
curr_hidden: torch.Tensor,
|
||||
inv_perm: Optional[torch.Tensor],
|
||||
topk_weight: torch.Tensor,
|
||||
) -> None:
|
||||
"""
|
||||
Unpermute the final result and apply topk_weights, then perform the final
|
||||
reduction on the hidden states.
|
||||
"""
|
||||
M, topk = topk_weight.shape
|
||||
K = curr_hidden.shape[1]
|
||||
curr_hidden = curr_hidden[inv_perm, ...]
|
||||
curr_hidden = curr_hidden.view(-1, topk, K)
|
||||
curr_hidden.mul_(topk_weight.view(M, -1, 1))
|
||||
ops.moe_sum(curr_hidden, out)
|
||||
workspace3 = workspace3[inv_perm, ...]
|
||||
|
||||
return workspace3
|
||||
|
||||
|
||||
def deep_gemm_moe_fp8(
|
||||
@ -128,6 +165,7 @@ def deep_gemm_moe_fp8(
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input=False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
|
||||
@ -166,129 +204,24 @@ def deep_gemm_moe_fp8(
|
||||
Returns:
|
||||
- torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
|
||||
"""
|
||||
# Lazy import to avoid CUDA initialization problems.
|
||||
import deep_gemm as dg
|
||||
|
||||
assert expert_map is None, "Expert maps not supported yet"
|
||||
|
||||
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
|
||||
|
||||
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
|
||||
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
|
||||
assert hidden_states.dtype in [
|
||||
torch.float32, torch.float16, torch.bfloat16
|
||||
]
|
||||
assert w1.dtype == torch.float8_e4m3fn
|
||||
assert w2.dtype == torch.float8_e4m3fn
|
||||
assert w1.shape[0] == w2.shape[0], "Expert number mismatch"
|
||||
assert w1.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
|
||||
assert w1.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
|
||||
assert a1_scale is None or a1_scale.dim(
|
||||
) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[
|
||||
0] == hidden_states.shape[0], "Input scale shape mismatch"
|
||||
assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501
|
||||
|
||||
num_tokens, _ = hidden_states.shape
|
||||
E, N, _ = w1.shape
|
||||
K = w2.shape[1]
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = E
|
||||
|
||||
# We execute the fused_moe kernel in chunks to circumvent this issue:
|
||||
# https://github.com/vllm-project/vllm/issues/5938
|
||||
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
||||
|
||||
assert _valid_deep_gemm(hidden_states, w1, w2, expert_map)
|
||||
|
||||
if inplace:
|
||||
out_hidden_states = hidden_states
|
||||
else:
|
||||
out_hidden_states = torch.empty_like(hidden_states)
|
||||
|
||||
block_m = dg.get_m_alignment_for_contiguous_layout()
|
||||
block_shape = [block_m, block_m]
|
||||
|
||||
assert w1_scale is not None
|
||||
assert w2_scale is not None
|
||||
|
||||
# We attempt to transpose and align offline in Fp8MoEMethod, in which
|
||||
# case these calls will be nops. Otherwise, they'll be performed every
|
||||
# time the layer is executed.
|
||||
w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous()
|
||||
w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous()
|
||||
|
||||
M_sum = topk_ids.numel() + global_num_experts * (block_m - 1)
|
||||
M_sum = round_up(M_sum, block_m)
|
||||
|
||||
num_chunks = (num_tokens // CHUNK_SIZE) + 1
|
||||
|
||||
# We can reuse the memory between cache1 and cache3 because by the time
|
||||
# we need cache3, we're done with cache1
|
||||
workspace13 = torch.empty(M_sum * max(N, K),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype)
|
||||
|
||||
workspace1 = workspace13[:M_sum * N].view(M_sum, N)
|
||||
workspace2 = torch.empty((M_sum, N // 2),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype)
|
||||
workspace3 = workspace13[:M_sum * K].view(M_sum, K)
|
||||
|
||||
for chunk in range(num_chunks):
|
||||
begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
|
||||
min((chunk + 1) * CHUNK_SIZE,
|
||||
num_tokens))
|
||||
curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
|
||||
tokens_in_chunk, _ = curr_hidden_states.shape
|
||||
|
||||
if tokens_in_chunk == 0:
|
||||
break
|
||||
|
||||
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
|
||||
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
|
||||
|
||||
a1q_scale: Optional[torch.Tensor] = None
|
||||
|
||||
qcurr_hidden_states, a1q_scale = _fp8_quantize(curr_hidden_states,
|
||||
a1_scale, block_shape)
|
||||
|
||||
(qcurr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
|
||||
inv_perm) = _moe_permute(qcurr_hidden_states, a1q_scale,
|
||||
curr_topk_ids, global_num_experts,
|
||||
expert_map, block_m)
|
||||
|
||||
# Adjust the intermediate cache size and config for the last chunk.
|
||||
# Note that in most cases we only have one chunk so the cache size
|
||||
# and config are already set correctly and do not need to be adjusted.
|
||||
if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
|
||||
curr_M = sorted_token_ids.numel()
|
||||
workspace1 = _resize_cache(workspace1, (curr_M, N))
|
||||
workspace2 = _resize_cache(workspace2, (curr_M, N // 2))
|
||||
workspace3 = _resize_cache(workspace3, (curr_M, K))
|
||||
|
||||
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
(qcurr_hidden_states, a1q_scale), (w1, w1_scale), workspace1,
|
||||
expert_ids)
|
||||
|
||||
if activation == "silu":
|
||||
torch.ops._C.silu_and_mul(workspace2, workspace1.view(-1, N))
|
||||
elif activation == "gelu":
|
||||
torch.ops._C.gelu_and_mul(workspace2, workspace1.view(-1, N))
|
||||
else:
|
||||
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
|
||||
|
||||
a2q_scale: Optional[torch.Tensor] = None
|
||||
|
||||
qworkspace2, a2q_scale = _fp8_quantize(workspace2, a2_scale,
|
||||
block_shape)
|
||||
|
||||
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
(qworkspace2, a2q_scale), (w2, w2_scale), workspace3, expert_ids)
|
||||
|
||||
_moe_unpermute_and_reduce(
|
||||
out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||
workspace3.view(*workspace3.shape), inv_perm, curr_topk_weights)
|
||||
|
||||
return out_hidden_states
|
||||
fn = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(quant_dtype=torch.float8_e4m3fn,
|
||||
block_shape=deep_gemm_block_shape()),
|
||||
DeepGemmExperts(),
|
||||
)
|
||||
return fn(
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace,
|
||||
activation,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
vllm/model_executor/layers/fused_moe/fused_batched_moe.py (new file, 755 lines)
@@ -0,0 +1,755 @@
# SPDX-License-Identifier: Apache-2.0
"""Fused batched MoE kernel."""
from typing import Optional

import torch
import triton
import triton.language as tl

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.fused_moe import (
    get_config_dtype_str, try_get_optimal_moe_config)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache

@triton.jit
|
||||
def moe_mmk(
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
K,
|
||||
expert_id,
|
||||
a_scale_ptr,
|
||||
b_scale_ptr,
|
||||
# The stride variables represent how much to increase the ptr by when
|
||||
# moving by 1 element in a particular dimension. E.g. `stride_am` is
|
||||
# how much to increase `a_ptr` by to get the element one row down
|
||||
# (A has M rows).
|
||||
stride_ak,
|
||||
stride_bk,
|
||||
stride_asm,
|
||||
stride_ask,
|
||||
stride_bse,
|
||||
stride_bsk,
|
||||
stride_bsn,
|
||||
# Offsets and masks
|
||||
offs_m,
|
||||
offs_n,
|
||||
mask_m,
|
||||
# Block size for block-wise quantization
|
||||
group_n: tl.constexpr,
|
||||
group_k: tl.constexpr,
|
||||
# Meta-parameters
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
compute_type: tl.constexpr,
|
||||
use_w8a8: tl.constexpr,
|
||||
use_w8a16: tl.constexpr):
|
||||
|
||||
offs_k = tl.arange(0, BLOCK_K)
|
||||
|
||||
if use_w8a16:
|
||||
b_scale_ptrs = b_scale_ptr + expert_id * stride_bse + offs_n[
|
||||
None, :] * stride_bsn
|
||||
b_scale = tl.load(b_scale_ptrs)
|
||||
|
||||
if use_w8a8:
|
||||
# block-wise
|
||||
if group_k > 0 and group_n > 0:
|
||||
a_scale_ptrs = a_scale_ptr + offs_m * stride_asm
|
||||
offs_bsn = offs_n // group_n
|
||||
b_scale_ptrs = (b_scale_ptr + expert_id * stride_bse +
|
||||
offs_bsn * stride_bsn)
|
||||
# tensor-wise
|
||||
else:
|
||||
a_scale = tl.load(a_scale_ptr)
|
||||
b_scale = tl.load(b_scale_ptr + expert_id)
|
||||
|
||||
# -----------------------------------------------------------
|
||||
# Iterate to compute a block of the C matrix.
|
||||
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
|
||||
# of fp32 values for higher accuracy.
|
||||
# `accumulator` will be converted back to fp16 after the loop.
|
||||
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_K)):
|
||||
# Load the next block of A and B, generate a mask by checking the
|
||||
# K dimension.
|
||||
a = tl.load(a_ptrs,
|
||||
mask=mask_m[:, None] & (offs_k[None, :] < K - k * BLOCK_K),
|
||||
other=0.0)
|
||||
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)
|
||||
# We accumulate along the K dimension.
|
||||
if use_w8a16:
|
||||
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
|
||||
elif use_w8a8:
|
||||
if group_k > 0 and group_n > 0:
|
||||
k_start = k * BLOCK_K
|
||||
offs_ks = k_start // group_k
|
||||
a_scale = tl.load(a_scale_ptrs + offs_ks * stride_ask,
|
||||
mask=mask_m,
|
||||
other=0.0)
|
||||
b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
|
||||
|
||||
accumulator += tl.dot(a, b) * a_scale[:,
|
||||
None] * b_scale[None, :]
|
||||
else:
|
||||
if use_w8a8:
|
||||
# acc used to enable fp8_fast_accum
|
||||
accumulator = tl.dot(a, b, acc=accumulator)
|
||||
else:
|
||||
accumulator += tl.dot(a, b)
|
||||
else:
|
||||
accumulator += tl.dot(a, b)
|
||||
# Advance the ptrs to the next K block.
|
||||
a_ptrs += BLOCK_K * stride_ak
|
||||
b_ptrs += BLOCK_K * stride_bk
|
||||
|
||||
if use_w8a16:
|
||||
accumulator = (accumulator * b_scale).to(compute_type)
|
||||
elif use_w8a8:
|
||||
if group_k > 0 and group_n > 0:
|
||||
accumulator = accumulator.to(compute_type)
|
||||
else:
|
||||
accumulator = (accumulator * a_scale * b_scale).to(compute_type)
|
||||
else:
|
||||
accumulator = accumulator.to(compute_type)
|
||||
|
||||
return accumulator
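# --- Illustrative sketch (hypothetical, not part of this commit) ---
# Plain-Python view of how moe_mmk indexes block-wise quantization scales:
# the K-block starting at k * BLOCK_K uses scale group k_start // group_k, so
# a single scale covers group_k consecutive K elements.
def _example_blockwise_scale_groups():
    group_k, BLOCK_K = 128, 64
    groups = []
    for k in range(4):                  # K-block loop, as in the kernel
        k_start = k * BLOCK_K
        groups.append(k_start // group_k)
    assert groups == [0, 0, 1, 1]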
|
||||
|
||||
|
||||
@triton.jit
|
||||
def expert_triton_kernel(
|
||||
a_ptr, #[max_tokens, K]
|
||||
b_ptr, #[K, N]
|
||||
c_ptr, #[max_tokens, N]
|
||||
expert_id,
|
||||
compute_type: tl.constexpr,
|
||||
# Dimensions
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
# Quantization data
|
||||
a_scale_ptr,
|
||||
b_scale_ptr,
|
||||
b_zp_ptr,
|
||||
# strides
|
||||
stride_am,
|
||||
stride_ak,
|
||||
stride_bk,
|
||||
stride_bn,
|
||||
stride_cm,
|
||||
stride_cn,
|
||||
stride_asm,
|
||||
stride_ask,
|
||||
stride_bse,
|
||||
stride_bsk,
|
||||
stride_bsn,
|
||||
# Blockwise quantization data
|
||||
group_n,
|
||||
group_k,
|
||||
# Quantization schemes
|
||||
use_fp8_w8a8: tl.constexpr,
|
||||
use_int8_w8a16: tl.constexpr,
|
||||
# Kernel config
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr):
|
||||
|
||||
offs_m = tl.arange(0, BLOCK_M)
|
||||
offs_n = tl.arange(0, BLOCK_N) % N
|
||||
offs_k = tl.arange(0, BLOCK_K)
|
||||
mask_m = offs_m < M
|
||||
|
||||
a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
|
||||
b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
|
||||
|
||||
accumulator = moe_mmk(
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
K,
|
||||
expert_id,
|
||||
a_scale_ptr,
|
||||
b_scale_ptr,
|
||||
# The stride variables represent how much to increase the ptr by when
|
||||
# moving by 1 element in a particular dimension. E.g. `stride_am` is
|
||||
# how much to increase `a_ptr` by to get the element one row down
|
||||
# (A has M rows).
|
||||
stride_ak,
|
||||
stride_bk,
|
||||
stride_asm,
|
||||
stride_ask,
|
||||
stride_bse,
|
||||
stride_bsk,
|
||||
stride_bsn,
|
||||
# Offsets and masks
|
||||
offs_m,
|
||||
offs_n,
|
||||
mask_m,
|
||||
# Block size for block-wise quantization
|
||||
group_n,
|
||||
group_k,
|
||||
# Meta-parameters
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
BLOCK_K,
|
||||
compute_type,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16)
|
||||
|
||||
# store in C
|
||||
offs_cn = tl.arange(0, BLOCK_N)
|
||||
c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_cn[None, :] * stride_cn
|
||||
c_mask = mask_m[:, None] & (offs_cn[None, :] < N)
|
||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def batched_triton_kernel(
|
||||
a_ptr, # [E, max_num_tokens, K]
|
||||
b_ptr, # [E, K, N]
|
||||
c_ptr, # [E, max_num_tokens, N]
|
||||
expert_num_tokens, # [E]
|
||||
compute_type: tl.constexpr,
|
||||
# Dimensions
|
||||
max_num_tokens,
|
||||
K,
|
||||
N,
|
||||
# Quantization data
|
||||
a_scale_ptr,
|
||||
b_scale_ptr,
|
||||
b_zp_ptr,
|
||||
# The stride variables represent how much to increase the ptr by when
|
||||
# moving by 1 element in a particular dimension. E.g. `stride_am` is
|
||||
# how much to increase `a_ptr` by to get the element one row down
|
||||
# (A has M rows).
|
||||
stride_ae,
|
||||
stride_am,
|
||||
stride_ak,
|
||||
stride_be,
|
||||
stride_bk,
|
||||
stride_bn,
|
||||
stride_ce,
|
||||
stride_cm,
|
||||
stride_cn,
|
||||
stride_asm,
|
||||
stride_ask,
|
||||
stride_bse,
|
||||
stride_bsk,
|
||||
stride_bsn,
|
||||
# Blockwise quantization data
|
||||
group_n: tl.constexpr,
|
||||
group_k: tl.constexpr,
|
||||
# Quantization schemes
|
||||
use_fp8_w8a8: tl.constexpr,
|
||||
use_int8_w8a16: tl.constexpr,
|
||||
# Kernel config
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr):
|
||||
expert_id = tl.program_id(axis=0)
|
||||
e_num_tokens = tl.load(expert_num_tokens + expert_id)
|
||||
if e_num_tokens == 0:
|
||||
# Early exit
|
||||
return
|
||||
|
||||
pid_mn = tl.program_id(axis=1)
|
||||
#num_pid_m = tl.cdiv(max_num_tokens, BLOCK_M)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_N)
|
||||
pid_m = pid_mn // num_pid_n
|
||||
pid_n = pid_mn % num_pid_n
|
||||
|
||||
cta_m_start = pid_m * BLOCK_M
|
||||
cta_n_start = pid_n * BLOCK_N
|
||||
if cta_m_start >= e_num_tokens:
|
||||
# Early exit
|
||||
return
|
||||
|
||||
cta_m_size = min(BLOCK_M, e_num_tokens - cta_m_start)
|
||||
cta_n_size = min(BLOCK_N, N - cta_n_start)
|
||||
|
||||
a_ptr = a_ptr + expert_id * stride_ae + cta_m_start * stride_am
|
||||
b_ptr = b_ptr + expert_id * stride_be + cta_n_start * stride_bn
|
||||
c_ptr = (c_ptr + expert_id * stride_ce + cta_m_start * stride_cm +
|
||||
cta_n_start * stride_cn)
|
||||
|
||||
expert_triton_kernel(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
expert_id,
|
||||
compute_type,
|
||||
cta_m_size, # M
|
||||
cta_n_size, # N
|
||||
K, # K
|
||||
a_scale_ptr,
|
||||
b_scale_ptr,
|
||||
b_zp_ptr,
|
||||
# Strides
|
||||
stride_am,
|
||||
stride_ak,
|
||||
stride_bk,
|
||||
stride_bn,
|
||||
stride_cm,
|
||||
stride_cn,
|
||||
stride_asm,
|
||||
stride_ask,
|
||||
stride_bse,
|
||||
stride_bsk,
|
||||
stride_bsn,
|
||||
# Blockwise quantization data
|
||||
group_n,
|
||||
group_k,
|
||||
# Quantization schemes
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
# Kernel config
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
BLOCK_K)
|
||||
|
||||
|
||||
def invoke_moe_batched_triton_kernel(
|
||||
A: torch.Tensor, # [E, max_tokens, K]
|
||||
B: torch.Tensor, # [E, K, N]
|
||||
C: torch.Tensor, # [E, max_tokens, N]
|
||||
expert_num_tokens: torch.Tensor, # [E]
|
||||
compute_type: tl.dtype,
|
||||
# Quantization data
|
||||
A_scale: torch.Tensor,
|
||||
B_scale: torch.Tensor,
|
||||
B_zp: torch.Tensor,
|
||||
# Quantization schemes
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
config: dict[str, int],
|
||||
block_shape: Optional[list[int]] = None):
|
||||
|
||||
assert not use_int4_w4a16
|
||||
max_num_tokens = A.size(1)
|
||||
K = A.size(2)
|
||||
N = C.size(2)
|
||||
|
||||
BLOCK_M = config['BLOCK_SIZE_M']
|
||||
BLOCK_N = config['BLOCK_SIZE_N']
|
||||
BLOCK_K = config['BLOCK_SIZE_K']
|
||||
assert (torch.compiler.is_compiling()
|
||||
or torch.cuda.is_current_stream_capturing()
|
||||
or max_num_tokens % BLOCK_M == 0)
|
||||
|
||||
grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) *
|
||||
triton.cdiv(B.size(1), BLOCK_N))
|
||||
|
||||
batched_triton_kernel[grid](
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
expert_num_tokens,
|
||||
compute_type,
|
||||
# Dimensions
|
||||
max_num_tokens,
|
||||
K,
|
||||
N,
|
||||
# Quantization data
|
||||
A_scale,
|
||||
B_scale,
|
||||
B_zp,
|
||||
# Strides
|
||||
A.stride(0),
|
||||
A.stride(1),
|
||||
A.stride(2),
|
||||
B.stride(0),
|
||||
B.stride(2),
|
||||
B.stride(1),
|
||||
C.stride(0),
|
||||
C.stride(1),
|
||||
C.stride(2),
|
||||
A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
|
||||
A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
|
||||
B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
|
||||
B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
|
||||
B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
|
||||
# Blockwise quantization data
|
||||
0 if block_shape is None else block_shape[0],
|
||||
0 if block_shape is None else block_shape[1],
|
||||
# Quantization schemes
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
# Kernel config
|
||||
BLOCK_M=BLOCK_M,
|
||||
BLOCK_N=BLOCK_N,
|
||||
BLOCK_K=BLOCK_K)
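# --- Illustrative sketch (hypothetical, not part of this commit) ---
# The launch grid above is (num_local_experts, M-tiles * N-tiles). Example
# numbers are made up: 8 local experts, max_num_tokens=256, BLOCK_M=64, an
# output width of 4096 and BLOCK_N=128.
def _example_batched_kernel_grid():
    def cdiv(a, b):
        return -(-a // b)

    num_experts, max_num_tokens, BLOCK_M = 8, 256, 64
    out_width, BLOCK_N = 4096, 128
    grid = (num_experts,
            cdiv(max_num_tokens, BLOCK_M) * cdiv(out_width, BLOCK_N))
    assert grid == (8, 4 * 32)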
|
||||
|
||||
|
||||
class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
    """
    A reference prepare/finalize class that reorganizes the tokens into
    expert batched format, i.e. E x max_num_tokens x K. This is the format
    that the PPLX dispatch/combine kernels use.
    """

    def __init__(self, max_num_tokens: Optional[int], world_size: int,
                 dp_size: int, rank: int):
        super().__init__()
        self.world_size = world_size
        self.dp_size = dp_size
        self.rank = rank
        self.max_num_tokens = max_num_tokens

def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
assert a1.dim() == 2
|
||||
assert topk_ids.dim() == 2
|
||||
assert topk_ids.size(0) == a1.size(0)
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
topk = topk_ids.size(1)
|
||||
# TODO: this only works for topK=1, will need to update for topK>1
|
||||
assert topk == 1, \
|
||||
"apply_router_weight_on_input is only implemented for topk=1"
|
||||
a1.mul_(topk_weights.to(a1.dtype))
|
||||
|
||||
num_tokens, hidden_dim = a1.size()
|
||||
topk = topk_ids.size(1)
|
||||
|
||||
if self.max_num_tokens is None:
|
||||
tokens_per_expert = torch.bincount(topk_ids.view(-1),
|
||||
minlength=num_experts)
|
||||
self.max_num_tokens = int(tokens_per_expert.max().item())
|
||||
else:
|
||||
tokens_per_expert = torch.zeros(num_experts,
|
||||
dtype=torch.int,
|
||||
device=a1.device)
|
||||
|
||||
assert num_experts % self.world_size == 0
|
||||
|
||||
num_local_experts = num_experts // self.world_size
|
||||
|
||||
b_a1 = torch.zeros(
|
||||
(num_local_experts, self.max_num_tokens, hidden_dim),
|
||||
dtype=a1.dtype,
|
||||
device=a1.device)
|
||||
|
||||
first_expert = num_local_experts * self.rank
|
||||
last_expert = first_expert + num_local_experts
|
||||
|
||||
for expert_id in range(first_expert, last_expert):
|
||||
topks = torch.any(topk_ids == expert_id, dim=1).flatten()
|
||||
rows = torch.count_nonzero(topks.flatten())
|
||||
b_a1[expert_id -
|
||||
first_expert, :rows, :] = a1[:topks.numel()][topks]
|
||||
tokens_per_expert[expert_id - first_expert] = rows
|
||||
|
||||
return b_a1, a1_scale, tokens_per_expert
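# --- Illustrative sketch (hypothetical, not part of this commit) ---
# Plain-torch picture of the expert-batched layout produced by prepare():
# tokens routed to each local expert are packed into row-block `expert_id` of
# a (num_local_experts, max_num_tokens, hidden_dim) buffer, and
# tokens_per_expert records how many rows of each block are valid.
def _example_expert_batched_layout():
    import torch
    num_local_experts, max_num_tokens, hidden_dim = 2, 4, 8
    a1 = torch.randn(3, hidden_dim)
    topk_ids = torch.tensor([[0], [1], [0]])  # topk = 1 routing
    b_a1 = torch.zeros(num_local_experts, max_num_tokens, hidden_dim)
    tokens_per_expert = torch.zeros(num_local_experts, dtype=torch.int32)
    for expert_id in range(num_local_experts):
        rows = (topk_ids == expert_id).any(dim=1)
        n = int(rows.sum())
        b_a1[expert_id, :n] = a1[rows]
        tokens_per_expert[expert_id] = n
    assert tokens_per_expert.tolist() == [2, 1]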
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
fused_expert_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
) -> None:
|
||||
num_tokens = topk_ids.size(0)
|
||||
num_local_experts = fused_expert_output.size(0)
|
||||
K = fused_expert_output.size(-1)
|
||||
assert output.size(0) == num_tokens and output.size(1) == K
|
||||
|
||||
output.fill_(0)
|
||||
|
||||
first_expert = num_local_experts * self.rank
|
||||
last_expert = first_expert + num_local_experts
|
||||
|
||||
for expert_id in range(first_expert, last_expert):
|
||||
matching_tokens = topk_ids == expert_id
|
||||
topks = torch.any(matching_tokens, dim=1).flatten()
|
||||
rows = torch.count_nonzero(topks)
|
||||
rhs = fused_expert_output[expert_id - first_expert, :rows, :]
|
||||
if not apply_router_weight_on_input:
|
||||
rhs.mul_(topk_weights[matching_tokens].view(rhs.size(0), 1))
|
||||
output[topks] = output[topks] + rhs
|
||||
|
||||
|
||||
class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
    """
    A reference MoE expert class that operates on expert batched format,
    i.e. E x max_num_tokens x K. This is the format that the pplx
    dispatch/combine kernels use.
    """

def __init__(
|
||||
self,
|
||||
world_size: int,
|
||||
dp_size: int,
|
||||
max_num_tokens: Optional[int] = None,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
block_m: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
assert block_shape is None
|
||||
assert block_m is None
|
||||
assert not use_fp8_w8a8, "NYI"
|
||||
assert not use_int8_w8a8, "NYI"
|
||||
assert not use_int8_w8a16, "NYI"
|
||||
assert not use_int4_w4a16, "NYI"
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.world_size = world_size
|
||||
self.dp_size = dp_size
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
topk: int,
|
||||
num_experts: int,
|
||||
) -> tuple[int, int, torch.dtype]:
|
||||
assert a.dim() == 2
|
||||
num_dp = self.world_size // self.dp_size
|
||||
max_num_tokens = a.size(
|
||||
0) if self.max_num_tokens is None else self.max_num_tokens
|
||||
#print(f"WORKSPACE {max_num_tokens} {num_dp}")
|
||||
workspace13 = num_experts * max_num_tokens * num_dp * K
|
||||
workspace2 = max_num_tokens * num_dp * N
|
||||
return (workspace13, workspace2, a.dtype)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_num_tokens: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
assert hidden_states.dim() == 3
|
||||
assert expert_num_tokens is not None
|
||||
hidden_dim = hidden_states.size(-1)
|
||||
|
||||
if self.max_num_tokens is None:
|
||||
max_num_tokens = hidden_states.size(1)
|
||||
else:
|
||||
max_num_tokens = self.max_num_tokens
|
||||
|
||||
num_dp = self.world_size // self.dp_size
|
||||
num_experts = global_num_experts
|
||||
out = _resize_cache(workspace13,
|
||||
(num_experts, max_num_tokens * num_dp, hidden_dim))
|
||||
num_local_experts = w1.size(0)
|
||||
assert num_local_experts == w1.size(0), (
|
||||
f"{num_local_experts} == {w1.size(0)}")
|
||||
|
||||
N = w1.size(1) // 2
|
||||
|
||||
# Not cudagraph friendly
|
||||
assert (torch.compiler.is_compiling()
|
||||
or torch.cuda.is_current_stream_capturing()
|
||||
or torch.all(expert_num_tokens <= max_num_tokens * num_dp)), (
|
||||
f"{expert_num_tokens} <= {max_num_tokens * num_dp}")
|
||||
|
||||
for expert in range(num_local_experts):
|
||||
# Indexing expert_num_tokens doesn't work w/cudagraphs or inductor
|
||||
if (torch.compiler.is_compiling()
|
||||
or torch.cuda.is_current_stream_capturing()):
|
||||
num = max_num_tokens * num_dp
|
||||
else:
|
||||
num = int(expert_num_tokens[expert].item())
|
||||
tmp = _resize_cache(workspace2, (num, N))
|
||||
input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1)
|
||||
self.activation(activation, tmp, input)
|
||||
out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
    """
    A Triton based MoE expert class that operates on expert batched format,
    i.e. E x max_num_tokens x K. This is the format that the pplx
    dispatch/combine kernels use.
    """

def __init__(
|
||||
self,
|
||||
max_num_tokens: Optional[int] = None,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
world_size: int = 1,
|
||||
dp_size: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
self.use_fp8_w8a8 = use_fp8_w8a8
|
||||
self.use_int8_w8a8 = use_int8_w8a8
|
||||
self.use_int4_w4a16 = use_int4_w4a16
|
||||
self.use_int8_w8a16 = use_int8_w8a16
|
||||
self.block_shape = block_shape
|
||||
self.max_num_tokens = max_num_tokens
|
||||
assert not use_int8_w8a8, "NYI"
|
||||
assert not use_int4_w4a16, "NYI"
|
||||
self.world_size = world_size
|
||||
self.dp_size = dp_size
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
topk: int,
|
||||
num_experts: int,
|
||||
) -> tuple[int, int, torch.dtype]:
|
||||
assert a.dim() == 2
|
||||
num_dp = self.world_size // self.dp_size
|
||||
max_num_tokens = a.size(
|
||||
0) if self.max_num_tokens is None else self.max_num_tokens
|
||||
workspace13 = num_experts * max_num_tokens * num_dp * max(K, N)
|
||||
workspace2 = num_experts * max_num_tokens * num_dp * (N // 2)
|
||||
return (workspace13, workspace2, a.dtype)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_num_tokens: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
# Check constraints.
|
||||
if self.use_int4_w4a16:
|
||||
assert hidden_states.size(-1) // 2 == w1.size(2), (
|
||||
"Hidden size mismatch")
|
||||
else:
|
||||
assert hidden_states.size(-1) == w1.size(2), (
|
||||
f"Hidden size mismatch {hidden_states.size(-1)} "
|
||||
f"!= {w1.size(2)}")
|
||||
|
||||
assert hidden_states.is_contiguous(
|
||||
), "Hidden_states must be contiguous"
|
||||
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
|
||||
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
|
||||
assert hidden_states.dtype in [
|
||||
torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn
|
||||
]
|
||||
|
||||
# TODO: num_tokens -> max_num_tokens?
|
||||
E, num_tokens, N, K, top_k_num = mk._moe_problem_size(
|
||||
hidden_states, w1, w2, topk_ids)
|
||||
|
||||
assert w1.size(0) == E
|
||||
assert w2.size(0) == E
|
||||
|
||||
config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8,
|
||||
use_int8_w8a16=self.use_int8_w8a16,
|
||||
use_int4_w4a16=self.use_int4_w4a16,
|
||||
dtype=hidden_states.dtype)
|
||||
|
||||
config = try_get_optimal_moe_config(
|
||||
w1.size(),
|
||||
w2.size(),
|
||||
top_k_num,
|
||||
config_dtype,
|
||||
num_tokens,
|
||||
block_shape=self.block_shape,
|
||||
)
|
||||
|
||||
if hidden_states.dtype == torch.bfloat16:
|
||||
compute_type = tl.bfloat16
|
||||
elif hidden_states.dtype == torch.float16:
|
||||
compute_type = tl.float16
|
||||
elif hidden_states.dtype == torch.float32:
|
||||
compute_type = tl.float32
|
||||
elif hidden_states.dtype == torch.float8_e4m3fn:
|
||||
compute_type = tl.bfloat16
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported compute_type: {hidden_states.dtype}")
|
||||
|
||||
#print(f"shape: E={E}, M={num_tokens}, N={N}, K={K}, top_k={top_k_num}")
|
||||
# We can reuse the memory between these because by the time we need
|
||||
# cache3, we're done with cache1
|
||||
intermediate_cache1 = _resize_cache(workspace13, (E, num_tokens, N))
|
||||
intermediate_cache2 = _resize_cache(workspace2,
|
||||
(E, num_tokens, N // 2))
|
||||
intermediate_cache3 = _resize_cache(workspace13, (E, num_tokens, K))
|
||||
|
||||
# MM1
|
||||
invoke_moe_batched_triton_kernel(A=hidden_states,
|
||||
B=w1,
|
||||
C=intermediate_cache1,
|
||||
expert_num_tokens=expert_num_tokens,
|
||||
compute_type=compute_type,
|
||||
A_scale=a1q_scale,
|
||||
B_scale=w1_scale,
|
||||
B_zp=w1_zp,
|
||||
use_fp8_w8a8=self.use_fp8_w8a8,
|
||||
use_int8_w8a16=self.use_int8_w8a16,
|
||||
use_int4_w4a16=self.use_int4_w4a16,
|
||||
config=config,
|
||||
block_shape=self.block_shape)
|
||||
|
||||
# TODO: would be nice to use expert_num_tokens here to reduce
|
||||
# garbage compute
|
||||
self.activation(activation, intermediate_cache2.view(-1, N // 2),
|
||||
intermediate_cache1.view(-1, N))
|
||||
|
||||
#qintermediate_cache2 = intermediate_cache2
|
||||
a2q_scale = a2_scale
|
||||
# TODO (varun) : support w8a8
|
||||
assert not self.use_fp8_w8a8
|
||||
#if self.use_fp8_w8a8:
|
||||
# qintermediate_cache2, a2q_scale = _fp8_quantize(
|
||||
# intermediate_cache2, a2_scale, self.block_shape)
|
||||
|
||||
invoke_moe_batched_triton_kernel(A=intermediate_cache2,
|
||||
B=w2,
|
||||
C=intermediate_cache3,
|
||||
expert_num_tokens=expert_num_tokens,
|
||||
compute_type=compute_type,
|
||||
A_scale=a2q_scale,
|
||||
B_scale=w2_scale,
|
||||
B_zp=w2_zp,
|
||||
use_fp8_w8a8=self.use_fp8_w8a8,
|
||||
use_int8_w8a16=self.use_int8_w8a16,
|
||||
use_int4_w4a16=self.use_int4_w4a16,
|
||||
config=config,
|
||||
block_shape=self.block_shape)
|
||||
|
||||
return intermediate_cache3
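# --- Illustrative sketch (hypothetical, not part of this commit) ---
# How the two batched classes above could be composed through the
# modular-kernel interface, mirroring the FusedMoEModularKernel pattern used
# for CutlassExpertsFp8 and DeepGemmExperts elsewhere in this PR. Additional
# keyword arguments (activation, expert_map, quantization scales, ...) may be
# required depending on the configuration.
def _example_batched_modular_kernel(hidden_states, w1, w2, topk_weights,
                                    topk_ids, max_num_tokens, world_size,
                                    dp_size, rank):
    fn = mk.FusedMoEModularKernel(
        BatchedPrepareAndFinalize(max_num_tokens, world_size, dp_size, rank),
        BatchedTritonExperts(max_num_tokens=max_num_tokens,
                             world_size=world_size,
                             dp_size=dp_size),
    )
    return fn(hidden_states, w1, w2, topk_weights, topk_ids)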
|
||||
@ -8,16 +8,17 @@ from typing import Any, Callable, Optional
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||
_valid_deep_gemm, deep_gemm_moe_fp8)
|
||||
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
||||
moe_align_block_size)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8)
|
||||
from vllm.model_executor.layers.quantization.utils.int8_utils import (
|
||||
per_token_group_quant_int8, per_token_quant_int8)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
_resize_cache, moe_kernel_quantize_input)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils import direct_register_custom_op
|
||||
@ -484,6 +485,20 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
|
||||
assert topk_weights is None or topk_weights.stride(1) == 1
|
||||
assert sorted_token_ids.stride(0) == 1
|
||||
|
||||
if use_fp8_w8a8 or use_int8_w8a8:
|
||||
assert B_scale is not None
|
||||
assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0])
|
||||
== B_scale.shape[-2])
|
||||
assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1])
|
||||
== B_scale.shape[-1])
|
||||
|
||||
elif use_int8_w8a16 or use_int4_w4a16:
|
||||
assert B_scale is not None
|
||||
assert block_shape is None or block_shape[0] == 0
|
||||
else:
|
||||
assert A_scale is None
|
||||
assert B_scale is None
|
||||
|
||||
M = A.shape[0]
|
||||
num_tokens = M * top_k
|
||||
|
||||
@ -855,6 +870,7 @@ def fused_topk(
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
indices_type: Optional[torch.dtype] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
assert hidden_states.shape[0] == gating_output.shape[0], (
|
||||
"Number of tokens mismatch")
|
||||
@ -865,9 +881,10 @@ def fused_topk(
|
||||
topk,
|
||||
dtype=torch.float32,
|
||||
device=hidden_states.device)
|
||||
topk_ids = torch.empty(M,
|
||||
topk_ids = torch.empty(
|
||||
M,
|
||||
topk,
|
||||
dtype=torch.int32,
|
||||
dtype=torch.int32 if indices_type is None else indices_type,
|
||||
device=hidden_states.device)
|
||||
token_expert_indices = torch.empty(M,
|
||||
topk,
|
||||
@ -962,6 +979,20 @@ def get_config_dtype_str(
|
||||
return None
|
||||
|
||||
|
||||
# TODO (bnell): use scalar_type instead of bools?
def get_config_qtype(
    use_fp8_w8a8: bool,
    use_int8_w8a8: bool,
    use_int8_w8a16: bool,
    use_int4_w4a16: bool,
) -> Optional[torch.dtype]:
    if use_fp8_w8a8:
        return torch.float8_e4m3fn
    elif use_int8_w8a8:
        return torch.int8
    return None

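# --- Illustrative sketch (hypothetical, not part of this commit) ---
# get_config_qtype maps the quantization flags to the dtype that activations
# are quantized to; w8a16/w4a16 leave activations unquantized.
def _example_config_qtype():
    import torch
    assert get_config_qtype(True, False, False, False) == torch.float8_e4m3fn
    assert get_config_qtype(False, True, False, False) == torch.int8
    assert get_config_qtype(False, False, True, False) is None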
|
||||
|
||||
def inplace_fused_experts(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
@ -1128,7 +1159,10 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
allow_deep_gemm: bool = False) -> torch.Tensor:
|
||||
if (allow_deep_gemm and use_fp8_w8a8
|
||||
# For now, disable DeepGemm for small N (<= 512) until better
|
||||
# permute/unpermute ops are available.
|
||||
N = w1.shape[1]
|
||||
if (allow_deep_gemm and use_fp8_w8a8 and N > 512
|
||||
and _valid_deep_gemm(hidden_states, w1, w2, expert_map)):
|
||||
assert apply_router_weight_on_input is False
|
||||
return deep_gemm_moe_fp8(
|
||||
@ -1145,6 +1179,7 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
else:
|
||||
return dispatch_fused_experts_func(inplace)(
|
||||
@ -1171,60 +1206,8 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
block_shape=block_shape)
|
||||
|
||||
|
||||
def moe_kernel_prepare_input(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
B_scale: Optional[torch.Tensor],
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
per_channel_quant: bool,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
if use_fp8_w8a8:
|
||||
assert B_scale is not None
|
||||
if block_shape is None:
|
||||
# If weights are per-channel (per_channel_quant=True), then
|
||||
# activations apply per-token quantization. Otherwise, assume
|
||||
# activation tensor-wise fp8 quantization, dynamic or static
|
||||
A, A_scale = ops.scaled_fp8_quant(
|
||||
A, A_scale, use_per_token_if_dynamic=per_channel_quant)
|
||||
else:
|
||||
# activation block-wise fp8 quantization
|
||||
assert len(block_shape) == 2
|
||||
_, block_k = block_shape[0], block_shape[1]
|
||||
A, A_scale = per_token_group_quant_fp8(A, block_k)
|
||||
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
|
||||
# assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
|
||||
# assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
|
||||
elif use_int8_w8a8:
|
||||
assert B_scale is not None
|
||||
if block_shape is None:
|
||||
# activation channel-wise int8 quantization
|
||||
assert (per_channel_quant
|
||||
), "int8 quantization only supports block or channel-wise"
|
||||
A, A_scale = per_token_quant_int8(A)
|
||||
else:
|
||||
# activation block-wise int8 quantization
|
||||
assert len(block_shape) == 2
|
||||
_, block_k = block_shape[0], block_shape[1]
|
||||
A, A_scale = per_token_group_quant_int8(A, block_k)
|
||||
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
|
||||
# assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
|
||||
# assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
|
||||
elif use_int8_w8a16 or use_int4_w4a16:
|
||||
assert B_scale is not None
|
||||
assert block_shape is None or block_shape[0] == 0
|
||||
else:
|
||||
assert A_scale is None
|
||||
assert B_scale is None
|
||||
|
||||
return A, A_scale
|
||||
|
||||
|
||||
def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
def fused_experts_impl(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
@ -1245,13 +1228,15 @@ def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None):
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> torch.Tensor:
|
||||
# Check constraints.
|
||||
if use_int4_w4a16:
|
||||
assert hidden_states.shape[1] // 2 == w1.shape[
|
||||
2], "Hidden size mismatch"
|
||||
else:
|
||||
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
|
||||
assert hidden_states.shape[1] == w1.shape[2], (
|
||||
f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}")
|
||||
|
||||
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||
@ -1261,7 +1246,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
torch.float32, torch.float16, torch.bfloat16
|
||||
]
|
||||
|
||||
num_tokens, _ = hidden_states.shape
|
||||
num_tokens = hidden_states.shape[0]
|
||||
E, N, _ = w1.shape
|
||||
K = w2.shape[1]
|
||||
if global_num_experts == -1:
|
||||
@ -1276,6 +1261,11 @@ def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
dtype=hidden_states.dtype)
|
||||
|
||||
qtype = get_config_qtype(use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16)
|
||||
|
||||
get_config_func = functools.partial(
|
||||
try_get_optimal_moe_config,
|
||||
w1.shape,
|
||||
@ -1338,15 +1328,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
|
||||
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
|
||||
|
||||
qcurr_hidden_states, qa1_scale = moe_kernel_prepare_input(
|
||||
qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
|
||||
A=curr_hidden_states,
|
||||
B=w1,
|
||||
A_scale=a1_scale,
|
||||
B_scale=w1_scale,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
qtype=qtype,
|
||||
per_channel_quant=per_channel_quant,
|
||||
block_shape=block_shape)
|
||||
|
||||
@ -1357,7 +1342,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
invoke_fused_moe_kernel(qcurr_hidden_states,
|
||||
w1,
|
||||
intermediate_cache1,
|
||||
qa1_scale,
|
||||
a1q_scale,
|
||||
w1_scale,
|
||||
w1_zp,
|
||||
curr_topk_weights,
|
||||
@ -1384,22 +1369,17 @@ def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
else:
|
||||
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
|
||||
|
||||
qintermediate_cache2, qa2_scale = moe_kernel_prepare_input(
|
||||
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
|
||||
A=intermediate_cache2,
|
||||
B=w2,
|
||||
A_scale=a2_scale,
|
||||
B_scale=w2_scale,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
qtype=qtype,
|
||||
per_channel_quant=per_channel_quant,
|
||||
block_shape=block_shape)
|
||||
|
||||
invoke_fused_moe_kernel(qintermediate_cache2,
|
||||
w2,
|
||||
intermediate_cache3,
|
||||
qa2_scale,
|
||||
a2q_scale,
|
||||
w2_scale,
|
||||
w2_zp,
|
||||
curr_topk_weights,
|
||||
@ -1534,3 +1514,209 @@ def fused_moe(
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape)
|
||||
|
||||
|
||||
class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
per_channel_quant: bool,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
block_m: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.use_fp8_w8a8 = use_fp8_w8a8
|
||||
self.use_int4_w4a16 = use_int4_w4a16
|
||||
self.use_int8_w8a8 = use_int8_w8a8
|
||||
self.use_int8_w8a16 = use_int8_w8a16
|
||||
self.block_shape = block_shape
|
||||
self.block_m = block_m
|
||||
self.qtype = get_config_qtype(use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16)
|
||||
self.per_channel_quant = per_channel_quant
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
topk: int,
|
||||
num_experts: int,
|
||||
) -> tuple[int, int, torch.dtype]:
|
||||
factor = num_experts if a.dim() == 3 else 1
|
||||
workspace1 = M * topk * max(N * 2, K) * factor
|
||||
workspace2 = M * topk * N * factor
|
||||
return (workspace1, workspace2, a.dtype)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_num_tokens: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
# Check constraints.
|
||||
if self.use_int4_w4a16:
|
||||
assert hidden_states.size(-1) // 2 == w1.size(2), (
|
||||
"Hidden size mismatch")
|
||||
else:
|
||||
assert hidden_states.size(-1) == w1.size(2), \
|
||||
(f"Hidden size mismatch {hidden_states.size(-1)} "
|
||||
f"!= {w1.size(2)}")
|
||||
|
||||
assert hidden_states.is_contiguous(
|
||||
), "Hidden_states must be contiguous"
|
||||
assert hidden_states.dim() == 2
|
||||
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
|
||||
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
|
||||
assert hidden_states.dtype in [
|
||||
torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn
|
||||
]
|
||||
|
||||
E, num_tokens, N, K, top_k_num = mk._moe_problem_size(
|
||||
hidden_states, w1, w2, topk_ids)
|
||||
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = E
|
||||
|
||||
config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8,
|
||||
use_int8_w8a16=self.use_int8_w8a16,
|
||||
use_int4_w4a16=self.use_int4_w4a16,
|
||||
dtype=hidden_states.dtype)
|
||||
|
||||
config = try_get_optimal_moe_config(
|
||||
w1.shape,
|
||||
w2.shape,
|
||||
top_k_num,
|
||||
config_dtype,
|
||||
num_tokens,
|
||||
block_shape=self.block_shape,
|
||||
)
|
||||
|
||||
if hidden_states.dtype == torch.bfloat16:
|
||||
compute_type = tl.bfloat16
|
||||
elif hidden_states.dtype == torch.float16:
|
||||
compute_type = tl.float16
|
||||
elif hidden_states.dtype == torch.float32:
|
||||
compute_type = tl.float32
|
||||
elif hidden_states.dtype == torch.float8_e4m3fn:
|
||||
compute_type = tl.bfloat16
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported compute_type: {hidden_states.dtype}")
|
||||
|
||||
# We can reuse the memory between these because by the time we need
|
||||
# cache3, we're done with cache1
|
||||
intermediate_cache1 = _resize_cache(workspace13,
|
||||
(num_tokens, top_k_num, N))
|
||||
intermediate_cache2 = _resize_cache(workspace2,
|
||||
(num_tokens * top_k_num, N // 2))
|
||||
intermediate_cache3 = _resize_cache(workspace13,
|
||||
(num_tokens, top_k_num, K))
|
||||
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = (
|
||||
moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'],
|
||||
global_num_experts, expert_map))
|
||||
|
||||
invoke_fused_moe_kernel(hidden_states,
|
||||
w1,
|
||||
intermediate_cache1,
|
||||
a1q_scale,
|
||||
w1_scale,
|
||||
w1_zp,
|
||||
None,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
False,
|
||||
top_k_num,
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=self.use_fp8_w8a8,
|
||||
use_int8_w8a8=self.use_int8_w8a8,
|
||||
use_int8_w8a16=self.use_int8_w8a16,
|
||||
use_int4_w4a16=self.use_int4_w4a16,
|
||||
per_channel_quant=self.per_channel_quant,
|
||||
block_shape=self.block_shape)
|
||||
|
||||
self.activation(activation, intermediate_cache2,
|
||||
intermediate_cache1.view(-1, N))
|
||||
|
||||
a2q_scale: Optional[torch.Tensor] = None
|
||||
|
||||
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
|
||||
intermediate_cache2, a2_scale, self.qtype, self.per_channel_quant,
|
||||
self.block_shape)
|
||||
|
||||
invoke_fused_moe_kernel(qintermediate_cache2,
|
||||
w2,
|
||||
intermediate_cache3,
|
||||
a2q_scale,
|
||||
w2_scale,
|
||||
w2_zp,
|
||||
None,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
False,
|
||||
1,
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=self.use_fp8_w8a8,
|
||||
use_int8_w8a8=self.use_int8_w8a8,
|
||||
use_int8_w8a16=self.use_int8_w8a16,
|
||||
use_int4_w4a16=self.use_int4_w4a16,
|
||||
per_channel_quant=self.per_channel_quant,
|
||||
block_shape=self.block_shape)
|
||||
|
||||
return intermediate_cache3
|
||||
|
||||
|
||||
def modular_triton_fused_moe(
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
per_channel_quant: bool,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> mk.FusedMoEModularKernel:
|
||||
qtype = get_config_qtype(
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
)
|
||||
return mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(
|
||||
quant_dtype=qtype,
|
||||
per_channel_quant=per_channel_quant,
|
||||
block_shape=block_shape,
|
||||
),
|
||||
TritonExperts(
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
per_channel_quant=per_channel_quant,
|
||||
block_shape=block_shape,
|
||||
),
|
||||
)
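A minimal usage sketch of the modular path above, assuming a CUDA device with Triton available and bfloat16 weights; the shapes, scaling, and router outputs here are illustrative only (in vLLM they come from the FusedMoE layer and its router).

import torch
from vllm.model_executor.layers.fused_moe.fused_moe import modular_triton_fused_moe

# Unquantized kernel: all quantization flags off.
moe_kernel = modular_triton_fused_moe(use_fp8_w8a8=False,
                                      use_int8_w8a8=False,
                                      use_int8_w8a16=False,
                                      use_int4_w4a16=False,
                                      per_channel_quant=False)

M, K, N, E, topk = 32, 1024, 2048, 8, 2  # tokens, hidden, 2*intermediate, experts, topk
device, dtype = "cuda", torch.bfloat16
hidden_states = torch.randn(M, K, device=device, dtype=dtype)
w1 = torch.randn(E, N, K, device=device, dtype=dtype) / 10       # (E, 2*I, K)
w2 = torch.randn(E, K, N // 2, device=device, dtype=dtype) / 10  # (E, K, I)
topk_weights = torch.rand(M, topk, device=device, dtype=dtype)
topk_ids = torch.randint(0, E, (M, topk), device=device, dtype=torch.int32)

out = moe_kernel(hidden_states, w1, w2, topk_weights, topk_ids)  # (M, K)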
|
||||
|
||||
@ -1,15 +1,19 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import importlib
|
||||
import threading
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Callable, Optional
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parameter import UninitializedParameter
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.config import ParallelConfig, get_current_vllm_config
|
||||
from vllm.distributed import (get_dp_group, get_ep_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
@ -26,8 +30,17 @@ from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import CpuArchEnum
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
has_pplx = importlib.util.find_spec("pplx_kernels") is not None
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
from .fused_moe import fused_experts
|
||||
from .fused_batched_moe import (BatchedPrepareAndFinalize,
|
||||
BatchedTritonExperts)
|
||||
from .fused_moe import TritonExperts, fused_experts
|
||||
from .modular_kernel import (FusedMoEModularKernel,
|
||||
FusedMoEPermuteExpertsUnpermute,
|
||||
FusedMoEPrepareAndFinalize)
|
||||
if has_pplx:
|
||||
from .pplx_prepare_finalize import PplxPrepareAndFinalize
|
||||
else:
|
||||
fused_experts = None # type: ignore
|
||||
if is_rocm_aiter_moe_enabled():
|
||||
@ -42,6 +55,179 @@ else:
|
||||
fused_moe_pallas = None # type: ignore
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Note: this limit is somewhat arbitrary and might be changed later.
|
||||
# The size of the activations will be E x MOE_DP_CHUNK_SIZE x hidden_dim.
|
||||
MOE_DP_CHUNK_SIZE = 256
|
||||
|
||||
|
||||
@dataclass
|
||||
class FusedMoEParallelConfig:
|
||||
tp_size: int
|
||||
dp_size: int
|
||||
ep_size: int
|
||||
tp_rank: int
|
||||
dp_rank: int
|
||||
ep_rank: int
|
||||
|
||||
use_ep: bool # whether to use EP or not
|
||||
|
||||
@property
|
||||
def use_pplx_kernels(self):
|
||||
return self.dp_size > 1 and self.use_ep and has_pplx
|
||||
|
||||
@staticmethod
|
||||
def make(tp_size_: int, dp_size_: int,
|
||||
vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":
|
||||
"""
|
||||
Determine MoE parallel configuration. Based on the input tp_size_,
|
||||
dp_size_ and vllm's parallel config, determine what
|
||||
levels of parallelism to use in the fused moe layer.
|
||||
|
||||
Args:
|
||||
tp_size_ (int): tp_size passed into the FusedMoE constructor.
|
||||
dp_size_ (int): dp_size passed into the FusedMoE constructor.
|
||||
|
||||
vllm_parallel_config (ParallelConfig): vllm's parallel config
|
||||
object.
|
||||
|
||||
Examples:
|
||||
When there is no parallelism requested, i.e. tp_size_ = dp_size_ = 1,
|
||||
we simply return the sizes unaltered and the ranks set to 0.
|
||||
|
||||
Expert Parallelism is considered only when either dp_size_ or tp_size_
|
||||
is non-trivial.
|
||||
|
||||
When TP = 2, DP = 1 and EP = False, the configuration on different
|
||||
devices,
|
||||
- device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} //
|
||||
legend : {size, rank}
|
||||
- device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0}
|
||||
- Comment : Tensors are sharded across 2 devices.
|
||||
|
||||
When TP = 1, DP = 2 and EP = False, the configuration on different
|
||||
devices,
|
||||
- device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0}
|
||||
- device 1 : TP = {2, 1} DP = {2, 1} EP = {1, 0}
|
||||
- Comment: There are 2 engine instances and the tensors are sharded
|
||||
across 2 devices.
|
||||
|
||||
When TP = 2, DP = 2 and EP = False, the configuration on different
|
||||
devices,
|
||||
- device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0}
|
||||
- device 1: TP = {4, 1} DP = {2, 0} EP = {1, 0}
|
||||
- device 2: TP = {4, 2} DP = {2, 1} EP = {1, 0}
|
||||
- device 3: TP = {4, 3} DP = {2, 1} EP = {1, 0}
|
||||
- Comment: There are 2 engine instances and the tensors are sharded
|
||||
across 4 devices.
|
||||
|
||||
When, TP = 2, DP = 1 and EP = True, the configuration on different
|
||||
devices,
|
||||
- device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0}
|
||||
- device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1}
|
||||
- Comment: The experts are split between the 2 devices.
|
||||
|
||||
When, TP = 1, DP = 2 and EP = True, the configuration on different
|
||||
devices,
|
||||
- device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0}
|
||||
- device 1: TP = {1, 0} DP = {2, 1} EP = {2, 1}
|
||||
- Comment: There are 2 engine instances and the experts are split
|
||||
between the 2 devices.
|
||||
|
||||
When TP = 2, DP = 2 and EP = True, the configuration on different
|
||||
devices,
|
||||
- device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0}
|
||||
- device 1: TP = {1, 0} DP = {2, 0} EP = {4, 1}
|
||||
- device 2: TP = {1, 0} DP = {2, 1} EP = {4, 2}
|
||||
- device 3: TP = {1, 0} DP = {2, 1} EP = {4, 3}
|
||||
- Comment: There are 2 engine instances and the experts are split
|
||||
between the 4 devices.
|
||||
"""
|
||||
|
||||
def flatten_tp_across_dp(dp_rank: int):
|
||||
tp_rank = 0 if tp_size_ == 1 else get_tensor_model_parallel_rank()
|
||||
# There are actually dp_size_ * tp_size_ devices. Update tp_size
|
||||
# and tp_rank so we shard across all devices.
|
||||
tp_size = dp_size_ * tp_size_
|
||||
tp_rank = dp_rank * tp_size_ + tp_rank
|
||||
return tp_size, tp_rank
|
||||
|
||||
use_ep = (dp_size_ * tp_size_ > 1
|
||||
and vllm_parallel_config.enable_expert_parallel)
|
||||
|
||||
dp_size = dp_size_
|
||||
dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
|
||||
tp_size, tp_rank = flatten_tp_across_dp(dp_rank)
|
||||
|
||||
if not use_ep:
|
||||
return FusedMoEParallelConfig(tp_size=tp_size,
|
||||
tp_rank=tp_rank,
|
||||
dp_size=dp_size,
|
||||
dp_rank=dp_rank,
|
||||
ep_size=1,
|
||||
ep_rank=0,
|
||||
use_ep=False)
|
||||
# DP + EP / TP + EP / DP + TP + EP
|
||||
assert use_ep
|
||||
# In EP, each device owns a set of experts fully. There is no tensor
|
||||
# parallelism. Update tp_size, tp_rank, ep_size and ep_rank to reflect that.
|
||||
ep_size = tp_size
|
||||
ep_rank = tp_rank
|
||||
return FusedMoEParallelConfig(tp_size=1,
|
||||
tp_rank=0,
|
||||
dp_size=dp_size,
|
||||
dp_rank=dp_rank,
|
||||
ep_size=ep_size,
|
||||
ep_rank=ep_rank,
|
||||
use_ep=True)
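The mapping in the docstring above can be reproduced with plain arithmetic; the sketch below (a hypothetical sketch_moe_parallel helper, independent of vLLM's process groups) mirrors flatten_tp_across_dp and the EP promotion for the TP = 2, DP = 2 cases.

def sketch_moe_parallel(tp_size_: int, dp_size_: int, dp_rank: int,
                        tp_rank: int, enable_ep: bool):
    # Flatten TP across DP: there are dp_size_ * tp_size_ devices in total.
    tp_size = dp_size_ * tp_size_
    tp_rank = dp_rank * tp_size_ + tp_rank
    if not (enable_ep and dp_size_ * tp_size_ > 1):
        return dict(tp=(tp_size, tp_rank), dp=(dp_size_, dp_rank), ep=(1, 0))
    # EP: each device owns whole experts, so TP collapses to 1.
    return dict(tp=(1, 0), dp=(dp_size_, dp_rank), ep=(tp_size, tp_rank))

# TP = 2, DP = 2, EP on -> the device with dp_rank=1, tp_rank=1 gets EP = {4, 3}.
assert sketch_moe_parallel(2, 2, dp_rank=1, tp_rank=1, enable_ep=True)["ep"] == (4, 3)
# TP = 2, DP = 2, EP off -> the same device is TP rank 3 of 4.
assert sketch_moe_parallel(2, 2, dp_rank=1, tp_rank=1, enable_ep=False)["tp"] == (4, 3)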
|
||||
|
||||
|
||||
# Adapted from pplx-kernels tests/all_to_all_utils.py
|
||||
@dataclass
|
||||
class MoEConfig:
|
||||
num_experts: int
|
||||
experts_per_token: int
|
||||
hidden_dim: int
|
||||
|
||||
num_local_experts: int
|
||||
moe_parallel_config: FusedMoEParallelConfig
|
||||
|
||||
in_dtype: torch.dtype # The activation type.
|
||||
|
||||
# TODO: add more quantization params, blocked, per-token, etc.
|
||||
block_size: int = 128
|
||||
|
||||
@property
|
||||
def tp_size(self):
|
||||
return self.moe_parallel_config.tp_size
|
||||
|
||||
@property
|
||||
def dp_size(self):
|
||||
return self.moe_parallel_config.dp_size
|
||||
|
||||
@property
|
||||
def ep_size(self):
|
||||
return self.moe_parallel_config.ep_size
|
||||
|
||||
@property
|
||||
def tp_rank(self):
|
||||
return self.moe_parallel_config.tp_rank
|
||||
|
||||
@property
|
||||
def dp_rank(self):
|
||||
return self.moe_parallel_config.dp_rank
|
||||
|
||||
@property
|
||||
def ep_rank(self):
|
||||
return self.moe_parallel_config.ep_rank
|
||||
|
||||
@property
|
||||
def use_ep(self):
|
||||
return self.moe_parallel_config.use_ep
|
||||
|
||||
@property
|
||||
def use_pplx_kernels(self):
|
||||
return self.moe_parallel_config.use_pplx_kernels
|
||||
|
||||
|
||||
class FusedMoeWeightScaleSupported(Enum):
|
||||
TENSOR = "tensor"
|
||||
@ -58,6 +244,14 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
raise NotImplementedError
|
||||
|
||||
def set_prepare_finalize(
|
||||
self,
|
||||
dp_size: int,
|
||||
world_size: int,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
) -> bool:
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def apply(
|
||||
self,
|
||||
@ -80,12 +274,54 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class AllToAllCache:
|
||||
|
||||
def __init__(self):
|
||||
self._cache: WeakValueDictionary = WeakValueDictionary()
|
||||
self._lock = threading.RLock() # Reentrant lock for thread safety
|
||||
|
||||
def destroy(self):
|
||||
with self._lock:
|
||||
# TODO: can we do del self._cache?
|
||||
for _, a2a in self._cache.items():
|
||||
a2a.destroy()
|
||||
|
||||
def get_or_create(self, **kwargs):
|
||||
assert has_pplx
|
||||
import pplx_kernels as pplx
|
||||
|
||||
# Create a hashable key from the kwargs
|
||||
key = tuple(sorted((k, v) for k, v in kwargs.items()))
|
||||
|
||||
with self._lock:
|
||||
instance = self._cache.get(key)
|
||||
if instance is None:
|
||||
# TODO (varun): Add support to switch to intranode
|
||||
# when all communications are within the same
|
||||
# node.
|
||||
logger.debug("Create AllToAll %s", kwargs)
|
||||
instance = pplx.AllToAll.internode(**kwargs)
|
||||
self._cache[key] = instance
|
||||
return instance
|
||||
|
||||
|
||||
# Global singleton
|
||||
_all_to_all_cache = AllToAllCache()
|
||||
|
||||
|
||||
# Factory function as a cleaner interface
|
||||
def get_all_to_all(**kwargs):
|
||||
return _all_to_all_cache.get_or_create(**kwargs)
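A standalone illustration of the caching pattern above, using a hypothetical _Handle class instead of pplx.AllToAll: identical kwargs map to the same instance for as long as a strong reference to it exists, and entries disappear from the WeakValueDictionary once the last reference is dropped.

import threading
from weakref import WeakValueDictionary

class _Handle:
    def __init__(self, **kwargs):
        self.kwargs = kwargs

class _DemoCache:
    def __init__(self):
        self._cache = WeakValueDictionary()
        self._lock = threading.RLock()

    def get_or_create(self, **kwargs):
        key = tuple(sorted(kwargs.items()))
        with self._lock:
            instance = self._cache.get(key)
            if instance is None:
                instance = _Handle(**kwargs)
                self._cache[key] = instance
            return instance

cache = _DemoCache()
a = cache.get_or_create(world_size=4, rank=0)
b = cache.get_or_create(world_size=4, rank=0)
assert a is b  # same key -> cached instance reused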
|
||||
|
||||
|
||||
@CustomOp.register("unquantized_fused_moe")
|
||||
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
"""MoE method without quantization."""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, moe: MoEConfig):
|
||||
super().__init__()
|
||||
self.fused_experts = fused_experts
|
||||
self.moe = moe
|
||||
|
||||
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
@ -193,6 +429,47 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
|
||||
def set_prepare_finalize(
|
||||
self,
|
||||
dp_size: int,
|
||||
world_size: int,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
) -> bool:
|
||||
assert self.fused_experts == fused_experts
|
||||
|
||||
experts: Optional[FusedMoEPermuteExpertsUnpermute] = None
|
||||
|
||||
if isinstance(prepare_finalize,
|
||||
(BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
|
||||
logger.debug("BatchedTritonExperts %s", self.moe)
|
||||
experts = BatchedTritonExperts(
|
||||
max_num_tokens=MOE_DP_CHUNK_SIZE,
|
||||
world_size=world_size,
|
||||
dp_size=dp_size,
|
||||
use_fp8_w8a8=False,
|
||||
use_int8_w8a8=False,
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False,
|
||||
block_shape=None,
|
||||
)
|
||||
else:
|
||||
logger.debug("TritonExperts %s", self.moe)
|
||||
experts = TritonExperts(
|
||||
use_fp8_w8a8=False,
|
||||
use_int8_w8a8=False,
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False,
|
||||
block_shape=None,
|
||||
per_channel_quant=False,
|
||||
)
|
||||
|
||||
self.fused_experts = FusedMoEModularKernel(
|
||||
prepare_finalize,
|
||||
experts,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@ -221,9 +498,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=torch.uint32 if self.moe.use_pplx_kernels else None)
|
||||
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
assert not apply_router_weight_on_input
|
||||
assert expert_map is None
|
||||
return self.rocm_aiter_fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
@ -232,8 +512,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
|
||||
return fused_experts(
|
||||
else:
|
||||
return self.fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
@ -243,7 +523,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map)
|
||||
expert_map=expert_map,
|
||||
)
|
||||
|
||||
def forward_cpu(
|
||||
self,
|
||||
@ -399,6 +680,45 @@ def determine_expert_map(
|
||||
return (local_num_experts, expert_map)
|
||||
|
||||
|
||||
def _construct_prepare_finalize(
|
||||
moe: MoEConfig, quant_config: Optional[QuantizationConfig]
|
||||
) -> Optional[FusedMoEPrepareAndFinalize]:
|
||||
max_num_tokens = MOE_DP_CHUNK_SIZE
|
||||
world_size = moe.ep_size
|
||||
dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP.
|
||||
rank = moe.ep_rank
|
||||
|
||||
if moe.use_pplx_kernels:
|
||||
logger.debug("using PplxPrepareAndFinalize")
|
||||
|
||||
all_to_all = get_all_to_all(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_experts=moe.num_experts,
|
||||
experts_per_token=moe.experts_per_token, # topk
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
dp_size=dp_size,
|
||||
hidden_dim=moe.hidden_dim,
|
||||
hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize,
|
||||
# For blocked per token: set to
|
||||
# ceil_div(hidden_dim, block_size) * sizeof(float32)
|
||||
# For per-token: set to sizeof(float32)
|
||||
hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else
|
||||
((moe.hidden_dim + moe.block_size - 1) //
|
||||
moe.block_size * torch.float32.itemsize)))
|
||||
|
||||
return PplxPrepareAndFinalize(
|
||||
all_to_all,
|
||||
max_num_tokens=max_num_tokens,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
dp_size=dp_size,
|
||||
quant_dtype=moe.in_dtype,
|
||||
)
|
||||
|
||||
return None
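The hidden_dim_scale_bytes computation above is only non-zero for 8-bit activation types; a small worked example of the blocked per-token scale footprint, using a hypothetical scale_bytes helper and illustrative sizes:

import torch

def scale_bytes(in_dtype: torch.dtype, hidden_dim: int, block_size: int) -> int:
    # Non-quantized activations (itemsize > 1) carry no scales.
    if in_dtype.itemsize != 1:
        return 0
    # One float32 scale per block of hidden_dim elements.
    return ((hidden_dim + block_size - 1) // block_size) * torch.float32.itemsize

assert scale_bytes(torch.bfloat16, 4096, 128) == 0
assert scale_bytes(torch.float8_e4m3fn, 4096, 128) == 32 * 4  # 32 blocks * 4 bytes each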
|
||||
|
||||
|
||||
class FusedMoE(torch.nn.Module):
|
||||
"""FusedMoE layer for MoE models.
|
||||
|
||||
@ -449,21 +769,16 @@ class FusedMoE(torch.nn.Module):
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.params_dtype = params_dtype
|
||||
|
||||
# Note: here we guard against accessing the TP and DP groups when
|
||||
# uninitialized (this happens when testing)
|
||||
self.tp_size = (tp_size if tp_size is not None else
|
||||
get_tensor_model_parallel_world_size())
|
||||
tp_rank = 0 if self.tp_size == 1 else get_tensor_model_parallel_rank()
|
||||
self.dp_size = (dp_size
|
||||
if dp_size is not None else get_dp_group().world_size)
|
||||
self.dp_rank = (0
|
||||
if self.dp_size == 1 else get_dp_group().rank_in_group)
|
||||
self.global_num_experts = num_experts
|
||||
|
||||
# Use expert parallelism instead of tensor parallelism?
|
||||
vllm_config = get_current_vllm_config()
|
||||
use_ep = (vllm_config.parallel_config.enable_expert_parallel
|
||||
and self.tp_size * self.dp_size > 1)
|
||||
self.moe_parallel_config: FusedMoEParallelConfig = (
|
||||
FusedMoEParallelConfig.make(
|
||||
tp_size_=(tp_size if tp_size is not None else
|
||||
get_tensor_model_parallel_world_size()),
|
||||
dp_size_=(dp_size if dp_size is not None else
|
||||
get_dp_group().world_size),
|
||||
vllm_parallel_config=vllm_config.parallel_config))
|
||||
|
||||
self.global_num_experts = num_experts
|
||||
|
||||
# For smuggling this layer into the fused moe custom op
|
||||
self.use_direct_call = self.dp_size == 1
|
||||
@ -474,28 +789,17 @@ class FusedMoE(torch.nn.Module):
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
self.layer_name = prefix
|
||||
|
||||
if use_ep:
|
||||
# Set TP size to 1 to adjust for EP and adjust EP size and rank
|
||||
# for DP attention.
|
||||
self.ep_rank = tp_rank + self.tp_size * self.dp_rank
|
||||
self.tp_rank = 0
|
||||
self.ep_size = self.tp_size * self.dp_size
|
||||
self.tp_size = 1
|
||||
|
||||
# Determine expert maps
|
||||
if self.use_ep:
|
||||
self.local_num_experts, self.expert_map = determine_expert_map(
|
||||
ep_size=self.ep_size,
|
||||
ep_rank=self.ep_rank,
|
||||
global_num_experts=self.global_num_experts)
|
||||
else:
|
||||
# Adjust TP size for DP attention
|
||||
self.tp_rank = tp_rank + self.tp_size * self.dp_rank
|
||||
self.ep_rank = 0
|
||||
self.tp_size = self.tp_size * self.dp_size
|
||||
self.ep_size = 1
|
||||
self.local_num_experts = self.global_num_experts
|
||||
self.expert_map = None
|
||||
self.local_num_experts, self.expert_map = (self.global_num_experts,
|
||||
None)
|
||||
|
||||
self.top_k = top_k
|
||||
self.global_num_experts = num_experts
|
||||
|
||||
assert intermediate_size % self.tp_size == 0
|
||||
self.hidden_size = hidden_size
|
||||
@ -520,14 +824,40 @@ class FusedMoE(torch.nn.Module):
|
||||
from vllm_hpu_extension.ops import DynamicFusedMOE
|
||||
self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)
|
||||
|
||||
moe = MoEConfig(
|
||||
num_experts=self.global_num_experts,
|
||||
experts_per_token=top_k,
|
||||
hidden_dim=hidden_size,
|
||||
num_local_experts=self.local_num_experts,
|
||||
moe_parallel_config=self.moe_parallel_config,
|
||||
# TODO (bnell): this needs to be fixed for quantized types.
|
||||
in_dtype=params_dtype,
|
||||
)
|
||||
|
||||
# Note: get_quant_method will look at the layer's local_num_experts
|
||||
# for heuristic purposes, so it must be initialized first.
|
||||
quant_method: Optional[QuantizeMethodBase] = None
|
||||
|
||||
if quant_config is None:
|
||||
self.quant_method: Optional[QuantizeMethodBase] = (
|
||||
UnquantizedFusedMoEMethod())
|
||||
quant_method = UnquantizedFusedMoEMethod(moe)
|
||||
prepare_finalize = _construct_prepare_finalize(moe, quant_config)
|
||||
else:
|
||||
self.quant_method = quant_config.get_quant_method(self, prefix)
|
||||
assert self.quant_method is not None
|
||||
quant_method = quant_config.get_quant_method(self, prefix)
|
||||
# No pplx for quantized types yet.
|
||||
prepare_finalize = None
|
||||
|
||||
assert quant_method is not None
|
||||
assert isinstance(quant_method, FusedMoEMethodBase)
|
||||
self.quant_method = quant_method
|
||||
|
||||
if prepare_finalize is not None:
|
||||
world_size = moe.ep_size
|
||||
dp_size = int(moe.ep_size // moe.dp_size)
|
||||
success = self.quant_method.set_prepare_finalize(
|
||||
dp_size, world_size, prepare_finalize)
|
||||
if not success:
|
||||
logger.warning("DP+EP not supported for %s.",
|
||||
type(self.quant_method))
|
||||
|
||||
moe_quant_params = {
|
||||
"num_experts": self.local_num_experts,
|
||||
@ -546,6 +876,38 @@ class FusedMoE(torch.nn.Module):
|
||||
|
||||
self.quant_method.create_weights(layer=self, **moe_quant_params)
|
||||
|
||||
@property
|
||||
def tp_size(self):
|
||||
return self.moe_parallel_config.tp_size
|
||||
|
||||
@property
|
||||
def dp_size(self):
|
||||
return self.moe_parallel_config.dp_size
|
||||
|
||||
@property
|
||||
def ep_size(self):
|
||||
return self.moe_parallel_config.ep_size
|
||||
|
||||
@property
|
||||
def tp_rank(self):
|
||||
return self.moe_parallel_config.tp_rank
|
||||
|
||||
@property
|
||||
def dp_rank(self):
|
||||
return self.moe_parallel_config.dp_rank
|
||||
|
||||
@property
|
||||
def ep_rank(self):
|
||||
return self.moe_parallel_config.ep_rank
|
||||
|
||||
@property
|
||||
def use_ep(self):
|
||||
return self.moe_parallel_config.use_ep
|
||||
|
||||
@property
|
||||
def use_pplx_kernels(self):
|
||||
return self.moe_parallel_config.use_pplx_kernels
|
||||
|
||||
def _load_per_tensor_weight_scale(self, shard_id: str,
|
||||
param: torch.nn.Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
@ -830,7 +1192,8 @@ class FusedMoE(torch.nn.Module):
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None):
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
indices_type: Optional[torch.dtype] = None):
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
|
||||
# DeekSeekv2 uses grouped_top_k
|
||||
@ -846,21 +1209,52 @@ class FusedMoE(torch.nn.Module):
|
||||
topk_group=topk_group,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
if indices_type is not None:
|
||||
topk_ids = topk_ids.to(dtype=indices_type)
|
||||
elif custom_routing_function is None:
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=router_logits,
|
||||
topk=top_k,
|
||||
renormalize=renormalize)
|
||||
renormalize=renormalize,
|
||||
indices_type=indices_type,
|
||||
)
|
||||
else:
|
||||
topk_weights, topk_ids = custom_routing_function(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=router_logits,
|
||||
topk=top_k,
|
||||
renormalize=renormalize)
|
||||
if indices_type is not None:
|
||||
topk_ids = topk_ids.to(dtype=indices_type)
|
||||
|
||||
return topk_weights, topk_ids
|
||||
|
||||
def must_reduce_shared_expert_outputs(self) -> bool:
|
||||
"""
|
||||
The shared_experts are typically computed using the RowParallelLinear
|
||||
layer. The result of this function is typically used as
|
||||
the reduce_results argument to the module.
|
||||
When just tensor-parallel is used, it is not required to reduce
|
||||
the shared_experts results immediately. Instead we reduce
|
||||
once at the end of the MoE op. (Refer to DeepSeekV2MoE module)
|
||||
With EP and the pplx kernels, this is no longer viable as all
|
||||
GPU ranks in the DP group produce the complete set of hidden_states.
|
||||
Therefore it is required that we reduce the shared_experts output
|
||||
early.
|
||||
"""
|
||||
return self.use_pplx_kernels
|
||||
|
||||
def maybe_all_reduce_tensor_model_parallel(
|
||||
self, final_hidden_states: torch.Tensor):
|
||||
"""
|
||||
The pplx combine kernel reduces across GPU ranks by default.
|
||||
"""
|
||||
if self.use_pplx_kernels:
|
||||
return final_hidden_states
|
||||
else:
|
||||
return tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
if self.use_direct_call:
|
||||
@ -869,9 +1263,62 @@ class FusedMoE(torch.nn.Module):
|
||||
return torch.ops.vllm.moe_forward(hidden_states, router_logits,
|
||||
self.layer_name)
|
||||
|
||||
def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
|
||||
full_router_logits: torch.Tensor):
|
||||
|
||||
full_final_hidden_states = torch.empty_like(full_hidden_states)
|
||||
|
||||
def process_chunk(chunk_start, chunk_end, skip_result_store=False):
|
||||
hidden_states = full_hidden_states[chunk_start:chunk_end, :]
|
||||
router_logits = full_router_logits[chunk_start:chunk_end, :]
|
||||
|
||||
# Matrix multiply.
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
layer=self,
|
||||
x=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=self.top_k,
|
||||
renormalize=self.renormalize,
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
global_num_experts=self.global_num_experts,
|
||||
expert_map=self.expert_map,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
scoring_func=self.scoring_func,
|
||||
e_score_correction_bias=self.e_score_correction_bias,
|
||||
activation=self.activation,
|
||||
)
|
||||
|
||||
if not skip_result_store:
|
||||
full_final_hidden_states[chunk_start:chunk_end, :].copy_(
|
||||
final_hidden_states)
|
||||
|
||||
ctx = get_forward_context()
|
||||
max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
|
||||
moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE
|
||||
|
||||
num_tokens = full_hidden_states.size(0)
|
||||
for chunk_start_ in range(0, max_tokens_across_dp,
|
||||
moe_dp_chunk_size_per_rank):
|
||||
chunk_start = chunk_start_
|
||||
chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank,
|
||||
max_tokens_across_dp)
|
||||
# clamp start and end
|
||||
chunk_start = min(chunk_start, num_tokens - 1)
|
||||
chunk_end = min(chunk_end, num_tokens)
|
||||
|
||||
process_chunk(chunk_start,
|
||||
chunk_end,
|
||||
skip_result_store=chunk_start_ >= num_tokens)
|
||||
|
||||
return full_final_hidden_states
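The clamping above keeps every DP rank issuing the same number of pplx collectives even when their token counts differ; a small sketch of the loop bounds (a hypothetical chunk_schedule helper with illustrative values) shows how a rank with fewer tokens re-processes a dummy tail chunk with skip_result_store=True so the collectives stay in lockstep.

MOE_DP_CHUNK_SIZE = 256

def chunk_schedule(num_tokens: int, max_tokens_across_dp: int):
    schedule = []
    for chunk_start_ in range(0, max_tokens_across_dp, MOE_DP_CHUNK_SIZE):
        chunk_end = min(chunk_start_ + MOE_DP_CHUNK_SIZE, max_tokens_across_dp)
        # Clamp to this rank's real token range; skip storing once past the end.
        start = min(chunk_start_, num_tokens - 1)
        end = min(chunk_end, num_tokens)
        schedule.append((start, end, chunk_start_ >= num_tokens))
    return schedule

# A rank with 300 tokens whose DP peer has 900 still runs 4 chunks;
# the last two are dummy chunks (skip_result_store=True).
assert len(chunk_schedule(300, 900)) == 4
assert chunk_schedule(300, 900)[-1] == (299, 300, True)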
|
||||
|
||||
def forward_impl(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
assert self.quant_method is not None
|
||||
if self.moe_parallel_config.use_pplx_kernels:
|
||||
return self.forward_impl_chunked(hidden_states, router_logits)
|
||||
|
||||
if self.dp_size > 1:
|
||||
hidden_states, router_logits = get_ep_group().dispatch(
|
||||
|
||||
364
vllm/model_executor/layers/fused_moe/modular_kernel.py
Normal file
@ -0,0 +1,364 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
#
|
||||
# This file defines a set of base classes used to make MoE kernels more modular.
|
||||
# The goal is to be able to utilize different communication mechanisms with
|
||||
# any fused MoE kernel without needing to have combinatoric implementations.
|
||||
#
|
||||
# The fused moe kernels are broken down into the following components:
|
||||
#
|
||||
# [Router] → [Quantize-Dispatch] → [Permute-Experts-Unpermute] → [Combine]
|
||||
#
|
||||
# Each component will be independent of the others except for
|
||||
# [Quantize-Dispatch] and [Combine] (see below). The components can then be
|
||||
# mixed and matched so that DP+EP can be supported easily for multiple
|
||||
# MoE kernel implementations.
|
||||
#
|
||||
# The following main classes are defined:
|
||||
# * FusedMoEPrepareAndFinalize - an abstract base class for preparation of MoE
|
||||
# inputs (e.g. quantization, distribution) and finalization of Moe outputs.
|
||||
# The prepare method must take care of any needed quantization and the
|
||||
# finalize method must apply weights and do the final reduction of the output.
|
||||
# * FusedMoEPermuteExpertsUnpermute - an abstract base class for the main fused
|
||||
# MoE operation. One important feature to note is that this class does not
|
||||
# apply topk weights or reduce the final output.
|
||||
# * FusedMoEModularKernel - an interface class that combines a
|
||||
# FusedMoEPrepareAndFinalize and a FusedMoEPermuteExpertsUnpermute to
|
||||
# provide the standard fused MoE kernel interface.
|
||||
#
|
||||
# [Quantize-Prepare] and [Finalize] functionality are bundled into a single
|
||||
# class `FusedMoEPrepareAndFinalize` since they could use collective
|
||||
# communication mechanisms that need to be consistent.
|
||||
#
|
||||
|
||||
|
||||
def _moe_problem_size(
|
||||
a1: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
) -> tuple[int, int, int, int, int]:
|
||||
"""
|
||||
Extract the MoE problem size from the given tensor arguments:
|
||||
- a: The hidden states, input to the MoE layer.
|
||||
- w1: The first set of expert weights.
|
||||
- w2: The second set of expert weights.
|
||||
- topk_ids: The topk ids.
|
||||
|
||||
Note: extracting the problem shape from the weight and activation tensors is
|
||||
not obvious. It needs to be done this way specifically due to subtle issues
|
||||
with particular kernels, e.g. the int4 kernels divide the trailing dimension
|
||||
by two, so it's not "correct" to extract N or K from the trailing dimension
|
||||
of w1 or w2. Similarly, some kernels transpose the weights, so this needs
|
||||
to be kept in mind.
|
||||
"""
|
||||
assert w1.dim() == 3 and w2.dim() == 3
|
||||
E, N, _ = w1.size()
|
||||
K = w2.size(1)
|
||||
|
||||
if a1.dim() == 2:
|
||||
# Make sure we are using the correct a1 (pre-permute).
|
||||
assert topk_ids.size(0) == a1.size(0), \
|
||||
f"{topk_ids.size(0)} != {a1.size(0)}"
|
||||
M = a1.size(0)
|
||||
else:
|
||||
assert a1.dim() == 3
|
||||
assert a1.size(0) == E, f"{a1.size(0)} == {E}"
|
||||
M = a1.size(1) # This is max_num_tokens
|
||||
|
||||
assert topk_ids.dim() == 2
|
||||
topk = topk_ids.size(1)
|
||||
|
||||
return E, M, N, K, topk
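A quick shape sanity-check for the function above; _moe_problem_size is module-private, so importing it directly is just for illustration, and CPU tensors suffice since only shapes are inspected.

import torch
from vllm.model_executor.layers.fused_moe.modular_kernel import _moe_problem_size

E, K, I, M, topk = 4, 128, 256, 16, 2
w1 = torch.empty(E, 2 * I, K)          # (E, N, K) with N = 2 * intermediate
w2 = torch.empty(E, K, I)              # (E, K, N // 2)
a1 = torch.empty(M, K)                 # unpermuted activations
topk_ids = torch.zeros(M, topk, dtype=torch.int32)

assert _moe_problem_size(a1, w1, w2, topk_ids) == (E, M, 2 * I, K, topk)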
|
||||
|
||||
|
||||
class FusedMoEPrepareAndFinalize(ABC):
|
||||
"""
|
||||
An abstract base class for the [Quantize-Prepare] and [Finalize] steps
|
||||
described above.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
"""
|
||||
Perform any quantization and/or dispatching needed
|
||||
for this kernel.
|
||||
- a1: The (unquantized) input to the MoE layer.
|
||||
- a1_scale: Optional scales for a1
|
||||
- a2_scale: Optional scales for the second MoE gemm. Required to make
|
||||
sure the quantization is consistent for both gemms.
|
||||
- topk_ids: The topk ids.
|
||||
- topk_weights: The topk weights.
|
||||
- num_experts: The total number of experts in the global expert space.
|
||||
- expert_map: A tensor mapping expert indices from the global expert
|
||||
space to the local expert space of the expert parallel shard.
|
||||
- apply_router_weight_on_input: When True, apply the weights to the
|
||||
activations, before quantization + dispatching.
|
||||
|
||||
Returns a tuple of:
|
||||
- quantized + dispatched a.
|
||||
- quantized + dispatched a1_scales.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def finalize(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
fused_expert_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Perform the combine step: apply the topk weights and reduce the
|
||||
fused experts output.
|
||||
- output: The output tensor, written in place. Must be (M, K) shape.
|
||||
- fused_expert_output: The unweighted, unreduced output of the fused
|
||||
experts, it will have (M, topk, K) shape.
|
||||
- topk_weights: The weights to be applied to the fused_experts_output.
|
||||
- topk_ids: The topk_ids.
|
||||
- apply_router_weight_on_input: When False, apply the weights to
|
||||
fused_expert_output.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
"""
|
||||
An abstract base class for the [Permute-Experts-Unpermute] step described
|
||||
above.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
topk: int,
|
||||
num_experts: int,
|
||||
) -> tuple[int, int, torch.dtype]:
|
||||
"""
|
||||
Compute the number of elements for the temporary outputs of the two
|
||||
gemms and activation in the fused expert function. Since the
|
||||
gemms run sequentially, the workspace for the first gemm can be shared
|
||||
with the workspace for the last gemm.
|
||||
|
||||
Returns a tuple of:
|
||||
- Number of workspace13 elements: must be large enough to hold the
|
||||
result of either expert gemm.
|
||||
- Number of workspace2 elements: must be large enough to hold the
|
||||
result of the activation function.
|
||||
- Workspace type: The dtype to use for the workspace tensors.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def activation(self, activation: str, output: torch.Tensor,
|
||||
input: torch.Tensor) -> None:
|
||||
assert output.size(-1) * 2 == input.size(-1)
|
||||
if activation == "silu":
|
||||
torch.ops._C.silu_and_mul(output, input)
|
||||
elif activation == "gelu":
|
||||
torch.ops._C.gelu_and_mul(output, input)
|
||||
else:
|
||||
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
|
||||
|
||||
@abstractmethod
|
||||
def apply(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_num_tokens: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function computes the intermediate result of a Mixture of Experts
|
||||
(MoE) layer using two sets of weights, w1 and w2.
|
||||
|
||||
Parameters:
|
||||
- hidden_states: (torch.Tensor): The (quantized) input tensor to the MoE
|
||||
layer.
|
||||
- w1 (torch.Tensor): The first set of expert weights.
|
||||
- w2 (torch.Tensor): The second set of expert weights.
|
||||
- topk_ids (torch.Tensor): A map of row to expert id.
|
||||
- activation (str): The activation function to apply after the first
|
||||
MoE layer.
|
||||
- global_num_experts (int): The total number of experts in the global
|
||||
expert space.
|
||||
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
|
||||
from the global expert space to the local expert space of the expert
|
||||
parallel shard.
|
||||
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
|
||||
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
|
||||
- w1_zp (Optional[torch.Tensor]): Optional zero points to be used for
|
||||
w1.
|
||||
- w2_zp (Optional[torch.Tensor]): Optional zero points to be used for
|
||||
w2.
|
||||
- a1q_scale (Optional[torch.Tensor]): Optional quantized scale to be
|
||||
used for a1.
|
||||
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
|
||||
- workspace13 (torch.Tensor): A scratch tensor used for gemm outputs
|
||||
must be large enough to hold output of either MoE gemm.
|
||||
- workspace2 (torch.Tensor): A scratch tensor used for the activation
|
||||
function.
|
||||
- expert_num_tokens: An optional tensor containing the number of tokens
|
||||
assigned to each expert when using batched experts format input.
|
||||
|
||||
Returns:
|
||||
- torch.Tensor: The unweighted, unreduced output tensor
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class FusedMoEModularKernel(torch.nn.Module):
|
||||
"""
|
||||
This class combines a FusedMoEPrepareAndFinalize instance and
|
||||
a FusedMoEPermuteExpertsUnpermute to provide an interface that
|
||||
is compatible with the `fused_experts` function in fused_moe.py.
|
||||
|
||||
It takes care of managing any required scratch space.
|
||||
|
||||
Note: Instances of this class should only be used for a single model
|
||||
layer due to any layer specific state that may be used by the component
|
||||
objects.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
fused_experts: FusedMoEPermuteExpertsUnpermute,
|
||||
):
|
||||
super().__init__()
|
||||
self.prepare_finalize = prepare_finalize
|
||||
self.fused_experts = fused_experts
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function computes a Mixture of Experts (MoE) layer using two sets
|
||||
of weights, w1 and w2, and top-k gating mechanism.
|
||||
|
||||
Parameters:
|
||||
- hidden_states: (torch.Tensor): The input tensor to the MoE layer.
|
||||
- w1 (torch.Tensor): The first set of expert weights.
|
||||
- w2 (torch.Tensor): The second set of expert weights.
|
||||
- topk_weights (torch.Tensor): The topk weights applied at the end of
|
||||
the layer.
|
||||
- topk_ids (torch.Tensor): A map of row to expert id.
|
||||
- inplace (bool): If True, perform the operation in-place.
|
||||
Defaults to False.
|
||||
- activation (str): The activation function to apply after the first
|
||||
MoE layer.
|
||||
- global_num_experts (int): The total number of experts in the global
|
||||
expert space.
|
||||
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
|
||||
from the global expert space to the local expert space of the expert
|
||||
parallel shard.
|
||||
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
|
||||
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
|
||||
- w1_zp (Optional[torch.Tensor]): Optional zero points to be used for
|
||||
w1.
|
||||
- w2_zp (Optional[torch.Tensor]): Optional zero points to be used for
|
||||
w2.
|
||||
- a1_scale (Optional[torch.Tensor]): Optional scale to be used for a1.
|
||||
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
|
||||
- apply_router_weight_on_input (bool): When true, the topk weights are
|
||||
applied directly on the inputs. This is only applicable when topk is
|
||||
1.
|
||||
|
||||
Returns:
|
||||
- torch.Tensor: The output tensor after applying the MoE layer.
|
||||
"""
|
||||
a1 = hidden_states
|
||||
E, M, N, K, top_k = _moe_problem_size(a1, w1, w2, topk_ids)
|
||||
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = E
|
||||
|
||||
output = a1 if inplace else torch.zeros_like(a1)
|
||||
|
||||
workspace13_shape, workspace2_shape, workspace_dtype = (
|
||||
self.fused_experts.workspace_shapes(a1, M, N, K, top_k,
|
||||
global_num_experts))
|
||||
|
||||
# We can reuse the memory between cache1 and cache3 because by the time
|
||||
# we need cache3, we're done with cache1
|
||||
workspace13 = torch.zeros(workspace13_shape,
|
||||
device=a1.device,
|
||||
dtype=workspace_dtype)
|
||||
workspace2 = torch.zeros(workspace2_shape,
|
||||
device=a1.device,
|
||||
dtype=workspace_dtype)
|
||||
|
||||
a1q, a1q_scale, expert_num_tokens = self.prepare_finalize.prepare(
|
||||
a1, a1_scale, a2_scale, topk_weights, topk_ids, global_num_experts,
|
||||
expert_map, apply_router_weight_on_input)
|
||||
|
||||
fused_out = self.fused_experts.apply(
|
||||
a1q,
|
||||
w1,
|
||||
w2,
|
||||
topk_ids,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=w1_zp,
|
||||
w2_zp=w2_zp,
|
||||
a1q_scale=a1q_scale,
|
||||
a2_scale=a2_scale,
|
||||
workspace13=workspace13,
|
||||
workspace2=workspace2,
|
||||
expert_num_tokens=expert_num_tokens,
|
||||
)
|
||||
|
||||
self.prepare_finalize.finalize(output, fused_out, topk_weights,
|
||||
topk_ids, apply_router_weight_on_input)
|
||||
|
||||
return output
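As a sketch of how a new expert implementation plugs into this interface, the following hypothetical NaiveTorchExperts (unquantized, silu only, no expert_map handling) implements just workspace_shapes and apply; it is a loop-based torch reference for clarity rather than performance, and would pair with MoEPrepareAndFinalizeNoEP from prepare_finalize.py.

import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk

class NaiveTorchExperts(mk.FusedMoEPermuteExpertsUnpermute):
    """Illustrative loop-over-tokens expert computation in plain torch."""

    def workspace_shapes(self, a, M, N, K, topk, num_experts):
        # No scratch space needed; the caller still allocates 1-element buffers.
        return (1, 1, a.dtype)

    def apply(self, hidden_states, w1, w2, topk_ids, activation,
              global_num_experts, expert_map, w1_scale, w2_scale, w1_zp,
              w2_zp, a1q_scale, a2_scale, workspace13, workspace2,
              expert_num_tokens):
        assert activation == "silu", "this sketch only handles silu"
        M, _ = hidden_states.shape
        topk = topk_ids.shape[1]
        K = w2.shape[1]
        out = hidden_states.new_zeros((M, topk, K))
        for t in range(M):
            for j in range(topk):
                e = int(topk_ids[t, j])
                gate, up = (w1[e] @ hidden_states[t]).chunk(2)
                out[t, j] = w2[e] @ (torch.nn.functional.silu(gate) * up)
        # Unweighted, unreduced output: finalize() applies the topk weights.
        return out

# naive_moe = mk.FusedMoEModularKernel(MoEPrepareAndFinalizeNoEP(), NaiveTorchExperts())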
|
||||
@ -3,6 +3,74 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
||||
moe_align_block_size)
|
||||
from vllm.model_executor.layers.fused_moe.utils import _fp8_perm
|
||||
|
||||
|
||||
def _moe_permute(
|
||||
curr_hidden_states: torch.Tensor,
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
curr_topk_ids: torch.Tensor,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
block_m: int,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
|
||||
Optional[torch.Tensor]]:
|
||||
"""
|
||||
Determine the sorted_token_ids, expert_ids for the given problem size.
|
||||
Permute the hidden states and scales according to `sorted_token_ids`.
|
||||
"""
|
||||
top_k_num = curr_topk_ids.size(1)
|
||||
|
||||
tokens_in_chunk = curr_hidden_states.size(0)
|
||||
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = (
|
||||
moe_align_block_size(curr_topk_ids,
|
||||
block_m,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
pad_sorted_ids=True))
|
||||
|
||||
inv_perm: Optional[torch.Tensor] = None
|
||||
|
||||
num_tokens = top_k_num * tokens_in_chunk
|
||||
sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
|
||||
expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)
|
||||
inv_perm = torch.argsort(sorted_token_ids)[:num_tokens]
|
||||
|
||||
# Permute according to sorted token ids.
|
||||
curr_hidden_states = _fp8_perm(curr_hidden_states,
|
||||
sorted_token_ids // top_k_num)
|
||||
|
||||
if a1q_scale is not None:
|
||||
a1q_scale = a1q_scale[sorted_token_ids // top_k_num]
|
||||
|
||||
return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
|
||||
inv_perm)
|
||||
|
||||
|
||||
def _moe_unpermute_and_reduce(
|
||||
out: torch.Tensor,
|
||||
curr_hidden: torch.Tensor,
|
||||
inv_perm: Optional[torch.Tensor],
|
||||
topk_weight: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Unpermute the final result and apply topk_weights, then perform the final
|
||||
reduction on the hidden states.
|
||||
"""
|
||||
M, topk = topk_weight.size()
|
||||
K = curr_hidden.size(-1)
|
||||
if inv_perm is not None:
|
||||
curr_hidden = curr_hidden[inv_perm, ...]
|
||||
curr_hidden = curr_hidden.view(-1, topk, K)
|
||||
if not apply_router_weight_on_input:
|
||||
curr_hidden.mul_(topk_weight.view(M, -1, 1))
|
||||
ops.moe_sum(curr_hidden, out)
|
||||
|
||||
|
||||
def moe_permute(
|
||||
hidden_states: torch.Tensor,
|
||||
@ -42,7 +110,7 @@ def moe_permute(
|
||||
- m_indices: m_indices for grouped gemm in deepgemm,`m_indices[i]` records
|
||||
the group to which the i-th row of the LHS belongs.
|
||||
"""
|
||||
n_token, n_hidden = hidden_states.shape
|
||||
n_token, n_hidden = hidden_states.size()
|
||||
assert (n_hidden * hidden_states.element_size()
|
||||
) % 16 == 0, "permute kernel needs hidden dim aligned to 16B"
|
||||
permuted_row_size = n_token * topk
|
||||
@ -102,7 +170,7 @@ def moe_unpermute(
|
||||
- hidden_states (torch.Tensor): The reduced and unpermuted activation
|
||||
tensor.
|
||||
"""
|
||||
n_token, n_hidden = topk_weights.shape[0], permuted_hidden_states.shape[-1]
|
||||
n_token, n_hidden = topk_weights.size(0), permuted_hidden_states.size(-1)
|
||||
assert (n_hidden * permuted_hidden_states.element_size()
|
||||
) % 16 == 0, "unpermute kernel needs hidden dim aligned to 16B"
|
||||
hidden_states = torch.empty((n_token, n_hidden),
|
||||
|
||||
147
vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
Normal file
@ -0,0 +1,147 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import Optional
|
||||
|
||||
import pplx_kernels as pplx
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
moe_kernel_quantize_input)
|
||||
|
||||
|
||||
# Note: use get_all_to_all() to get a cached AllToAll instance.
|
||||
# The max_num_tokens, world_size and dp_size must be the same
|
||||
# as the ones used to create the AllToAll.
|
||||
class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
|
||||
def __init__(self,
|
||||
a2a: pplx.AllToAll,
|
||||
max_num_tokens: int,
|
||||
world_size: int,
|
||||
rank: int,
|
||||
dp_size: int,
|
||||
quant_dtype: Optional[torch.dtype] = None,
|
||||
block_shape: Optional[list[int]] = None):
|
||||
super().__init__()
|
||||
assert max_num_tokens > 0
|
||||
self.a2a = a2a
|
||||
self.block_shape = block_shape
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.world_size = world_size
|
||||
self.rank = rank
|
||||
self.dp_size = dp_size
|
||||
self.quant_dtype = quant_dtype
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
rank_topk_weights: torch.Tensor,
|
||||
rank_topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
num_tokens = a1.size(0) # M
|
||||
hidden_dim = a1.size(-1) # K
|
||||
|
||||
assert rank_topk_ids.size(0) == num_tokens
|
||||
# assert expert_map is None, "NYI"
|
||||
|
||||
# Is this always going to be a1.device?
|
||||
device = a1.device
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
topk = rank_topk_ids.size(1)
|
||||
# TODO: this only works for topK=1, will need to update for topK>1
|
||||
assert topk == 1, (
|
||||
"apply_router_weight_on_input is only implemented for topk=1")
|
||||
a1 = a1 * rank_topk_weights.to(a1.dtype)
|
||||
|
||||
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
|
||||
a2_scale.numel() != 1 if a2_scale is not None else False)
|
||||
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale,
|
||||
self.quant_dtype,
|
||||
per_act_token,
|
||||
self.block_shape)
|
||||
|
||||
# rem_experts needs to be 0 for pplx to work properly.
|
||||
rem_experts = num_experts % self.world_size
|
||||
assert rem_experts == 0
|
||||
num_local_experts = ((num_experts // self.world_size) +
|
||||
(1 if self.rank < rem_experts else 0))
|
||||
|
||||
expert_num_tokens = torch.empty(
|
||||
num_local_experts,
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
num_dp = self.world_size // self.dp_size
|
||||
expert_x = torch.empty(
|
||||
(num_local_experts, self.max_num_tokens * num_dp, hidden_dim),
|
||||
dtype=a1q.dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
expert_x_scale: Optional[torch.Tensor] = None
|
||||
if a1q.dtype.itemsize == 1:
|
||||
float32_size = torch.float32.itemsize
|
||||
block_size = (self.block_shape[0] if self.block_shape is not None
|
||||
else 1) * float32_size
|
||||
expert_x_scale = torch.empty(
|
||||
(
|
||||
num_experts,
|
||||
expert_x.size(1),
|
||||
(expert_x.size(2) + block_size - 1) // block_size,
|
||||
),
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# This argument is optional, defaults to indices.size(0)
|
||||
# There's not much point setting this unless it is != indices.size(0)
|
||||
bound_m: Optional[torch.Tensor] = None
|
||||
|
||||
self.a2a.dispatch(
|
||||
out_expert_num_tokens=expert_num_tokens,
|
||||
out_expert_x=expert_x,
|
||||
out_expert_x_scale=expert_x_scale,
|
||||
dp_x=a1q,
|
||||
dp_x_scale=a1q_scale,
|
||||
indices=rank_topk_ids,
|
||||
bound_m=bound_m,
|
||||
)
|
||||
|
||||
return expert_x, expert_x_scale, expert_num_tokens
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
fused_expert_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
) -> None:
|
||||
num_tokens = output.size(0) # M
|
||||
# This argument is optional
|
||||
# There's not much point setting this unless it is != topk_ids.size(0)
|
||||
bound_m: Optional[torch.Tensor] = None
|
||||
|
||||
assert topk_ids.size(0) == num_tokens, (
|
||||
f"{topk_ids.size(0)} == {num_tokens}")
|
||||
assert output.size(0) <= self.max_num_tokens, (
|
||||
f"{output.size(0)} <= {self.max_num_tokens}")
|
||||
assert output.size(1) == fused_expert_output.size(-1)
|
||||
|
||||
# Set weights to 1 if we did them in dispatch. This is hacky.
|
||||
if apply_router_weight_on_input:
|
||||
topk_weights = torch.ones_like(topk_weights)
|
||||
|
||||
self.a2a.combine(out_tokens=output,
|
||||
indices=topk_ids,
|
||||
weights=topk_weights,
|
||||
expert_y=fused_expert_output,
|
||||
bound_m=bound_m)
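The dispatch buffers above are sized for the worst case; a small arithmetic sketch (hypothetical expert_x_shape helper, illustrative values) of the expert_x capacity each rank allocates in prepare():

def expert_x_shape(num_local_experts: int, max_num_tokens: int,
                   world_size: int, dp_size: int, hidden_dim: int):
    num_dp = world_size // dp_size
    # Each local expert can receive up to max_num_tokens rows from each
    # of the num_dp data-parallel ranks.
    return (num_local_experts, max_num_tokens * num_dp, hidden_dim)

# 16 local experts, chunk size 256, world_size=4 with dp_size=2 -> 512 rows per expert.
assert expert_x_shape(16, 256, 4, 2, 4096) == (16, 512, 4096)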
|
||||
60
vllm/model_executor/layers/fused_moe/prepare_finalize.py
Normal file
@ -0,0 +1,60 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
|
||||
_moe_unpermute_and_reduce)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
moe_kernel_quantize_input)
|
||||
|
||||
|
||||
class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_dtype: Optional[torch.dtype] = None,
|
||||
per_channel_quant: bool = False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.per_channel_quant = per_channel_quant
|
||||
self.block_shape = block_shape
|
||||
self.quant_dtype = quant_dtype
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool = False,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
if apply_router_weight_on_input:
|
||||
topk = topk_ids.size(1)
|
||||
# TODO: this only works for topK=1, will need to update for topK>1
|
||||
assert topk == 1, \
|
||||
"apply_router_weight_on_input is only implemented for topk=1"
|
||||
a1.mul_(topk_weights.to(a1.dtype))
|
||||
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale,
|
||||
self.quant_dtype,
|
||||
self.per_channel_quant,
|
||||
self.block_shape)
|
||||
|
||||
return a1q, a1q_scale, None
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
fused_expert_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
) -> None:
|
||||
_moe_unpermute_and_reduce(output, fused_expert_output, None,
|
||||
topk_weights, apply_router_weight_on_input)
|
||||
112
vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
Normal file
@ -0,0 +1,112 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
from typing import Optional

import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
    DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape)
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts


class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):

    def __init__(self,
                 use_fp8_w8a8: bool = False,
                 use_int8_w8a8: bool = False,
                 use_int8_w8a16: bool = False,
                 use_int4_w4a16: bool = False,
                 per_channel_quant: bool = False,
                 block_shape: Optional[list[int]] = None,
                 block_m: Optional[int] = None,
                 allow_deep_gemm: bool = False):
        super().__init__()
        self.triton_expert = TritonExperts(use_fp8_w8a8=use_fp8_w8a8,
                                           use_int8_w8a8=use_int8_w8a8,
                                           use_int4_w4a16=use_int4_w4a16,
                                           use_int8_w8a16=use_int8_w8a16,
                                           per_channel_quant=per_channel_quant,
                                           block_shape=block_shape,
                                           block_m=block_m)
        self.deep_gemm_expert = DeepGemmExperts()
        self.allow_deep_gemm = allow_deep_gemm
        self.use_fp8_w8a8 = use_fp8_w8a8

    def workspace_shapes(
        self,
        a: torch.Tensor,
        M: int,
        N: int,
        K: int,
        topk: int,
        num_experts: int,
    ) -> tuple[int, int, torch.dtype]:
        # Note: the deep gemm workspaces are strictly larger than the triton
        # workspaces, so we can be pessimistic here and allocate for DeepGemm
        # even if we fall back to triton later, e.g. if expert maps are set.
        if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K):
            return self.deep_gemm_expert.workspace_shapes(
                a, M, N, K, topk, num_experts)
        else:
            return self.triton_expert.workspace_shapes(a, M, N, K, topk,
                                                       num_experts)

    def apply(
        self,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int,
        expert_map: Optional[torch.Tensor],
        w1_scale: Optional[torch.Tensor],
        w2_scale: Optional[torch.Tensor],
        w1_zp: Optional[torch.Tensor],
        w2_zp: Optional[torch.Tensor],
        a1q_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor],
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_num_tokens: Optional[torch.Tensor],
    ) -> torch.Tensor:
        N = w1.size(1)
        if (self.allow_deep_gemm and self.use_fp8_w8a8 and N > 512
                and _valid_deep_gemm(hidden_states, w1, w2, expert_map)):
            return self.deep_gemm_expert.apply(
                hidden_states,
                w1,
                w2,
                topk_ids,
                activation,
                global_num_experts,
                expert_map,
                w1_scale,
                w2_scale,
                w1_zp,
                w2_zp,
                a1q_scale,
                a2_scale,
                workspace13,
                workspace2,
                expert_num_tokens,
            )
        else:
            return self.triton_expert.apply(
                hidden_states,
                w1,
                w2,
                topk_ids,
                activation,
                global_num_experts,
                expert_map,
                w1_scale,
                w2_scale,
                w1_zp,
                w2_zp,
                a1q_scale,
                a2_scale,
                workspace13,
                workspace2,
                expert_num_tokens,
            )
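As a usage note, the two new files compose naturally through the modular kernel: the prepare/finalize object handles quantization (and, with EP, dispatch/combine) while TritonOrDeepGemmExperts picks the GEMM path per call. A minimal sketch, mirroring the wiring Fp8MoEMethod performs further below; the block shape here is just an example value.

import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
    TritonOrDeepGemmExperts)

# Experts that use DeepGemm when the fp8 problem shape qualifies, Triton otherwise.
experts = TritonOrDeepGemmExperts(
    use_fp8_w8a8=True,
    block_shape=[128, 128],  # example block shape; depends on the checkpoint
    allow_deep_gemm=True,
)

# A callable intended as a drop-in replacement for the monolithic fused_experts.
fused_experts = mk.FusedMoEModularKernel(
    MoEPrepareAndFinalizeNoEP(quant_dtype=torch.float8_e4m3fn),
    experts,
)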
@ -7,6 +7,8 @@ import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)
from vllm.model_executor.layers.quantization.utils.int8_utils import (
    per_token_group_quant_int8, per_token_quant_int8)
from vllm.utils import cdiv


@ -15,26 +17,73 @@ def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor:
    Shrink the given tensor and apply the given view to it. This is
    used to resize the intermediate fused_moe caches.
    """
    assert prod(v) <= x.numel()
    assert prod(
        v) <= x.numel(), f"{prod(v)} <= {x.numel()}"  # CUDAGRAPH unfriendly?
    return x.flatten()[:prod(v)].view(*v)


def _fp8_quantize(
    A: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    block_shape: Optional[list[int]],
    per_act_token: bool,
    block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Perform fp8 quantization on the inputs. If a block_shape
    is provided, the output will be blocked.
    """
    if block_shape is None:
        A, A_scale = ops.scaled_fp8_quant(A, A_scale)
        A, A_scale = ops.scaled_fp8_quant(
            A, A_scale, use_per_token_if_dynamic=per_act_token)
    else:
        assert len(block_shape) == 2
        _, block_k = block_shape[0], block_shape[1]
        A, A_scale = per_token_group_quant_fp8(A, block_k)
        assert cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
        assert cdiv(A.size(-1), block_k) == A_scale.size(-1)

    return A, A_scale


def _int8_quantize(
    A: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    per_act_token: bool,
    block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Perform int8 quantization on the inputs. If a block_shape
    is provided, the output will be blocked.
    """

    # If the weights are per-channel (per_channel_quant=True), then the
    # activations use per-token quantization. Otherwise, assume tensor-wise
    # fp8/int8 quantization of the activations, dynamic or static.
    if block_shape is None:
        assert per_act_token, \
            "int8 quantization only supports block or channel-wise"
        A, A_scale = per_token_quant_int8(A)
    else:
        assert len(block_shape) == 2
        _, block_k = block_shape[0], block_shape[1]
        A, A_scale = per_token_group_quant_int8(A, block_k)
        assert cdiv(A.size(-1), block_k) == A_scale.size(-1)

    return A, A_scale


def moe_kernel_quantize_input(
    A: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    qtype: Optional[torch.dtype],
    per_channel_quant: bool,
    block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    if qtype == torch.float8_e4m3fn:
        return _fp8_quantize(A, A_scale, per_channel_quant, block_shape)
    elif qtype == torch.int8:
        return _int8_quantize(A, A_scale, per_channel_quant, block_shape)
    else:
        assert A_scale is None
        return A, A_scale


@ -42,7 +91,7 @@ def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """
    A permutation routine that works on fp8 types.
    """
    if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8:
    if torch.is_floating_point(m) and m.dtype.itemsize == 1:
        return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
    else:
        return m[idx, ...]
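A short usage sketch of the quantization helper above (not part of the diff); the tensor shape and device are illustrative only.

import torch

from vllm.model_executor.layers.fused_moe.utils import (
    moe_kernel_quantize_input)

a = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)

# Dynamic fp8 quantization with a single tensor-wide scale.
a_fp8, a_scale = moe_kernel_quantize_input(
    a, None, qtype=torch.float8_e4m3fn, per_channel_quant=False)

# Blocked fp8 quantization: one scale per 128-column group per token.
a_fp8_blk, a_blk_scale = moe_kernel_quantize_input(
    a, None, qtype=torch.float8_e4m3fn, per_channel_quant=False,
    block_shape=[128, 128])

# qtype=None means no quantization: the activations pass through unchanged.
a_out, no_scale = moe_kernel_quantize_input(
    a, None, qtype=None, per_channel_quant=False)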
@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

import functools
import importlib.util
from typing import Any, Callable, Optional

@ -9,6 +10,7 @@ from torch.nn import Module
from torch.nn.parameter import Parameter

import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger

@ -434,6 +436,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
    """

    def __init__(self, quant_config: Fp8Config):
        from vllm.model_executor.layers.fused_moe import fused_experts
        self.quant_config = quant_config
        self.block_quant = self.quant_config.weight_block_size is not None

@ -458,6 +461,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                logger.warning_once(
                    "DeepGemm not supported on the current platform.")

        self.fused_experts = functools.partial(
            fused_experts,
            block_shape=self.quant_config.weight_block_size,
            allow_deep_gemm=self.allow_deep_gemm)

    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                       intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):

@ -783,6 +791,31 @@ class Fp8MoEMethod(FusedMoEMethodBase):
            del layer.w13_input_scale
            del layer.w2_input_scale

    def set_prepare_finalize(
        self,
        dp_size: int,
        world_size: int,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
    ) -> bool:
        from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
            TritonOrDeepGemmExperts)

        if self.use_marlin or self.rocm_aiter_moe_enabled:
            return False

        experts = TritonOrDeepGemmExperts(
            use_fp8_w8a8=True,
            block_shape=self.quant_config.weight_block_size,
            allow_deep_gemm=self.allow_deep_gemm,
        )

        self.fused_experts = mk.FusedMoEModularKernel(
            prepare_finalize,
            experts,
        )

        return True

    def apply(
        self,
        layer: torch.nn.Module,

@ -801,10 +834,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
    ) -> torch.Tensor:
        from vllm.model_executor.layers.fused_moe import fused_experts
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            rocm_aiter_fused_experts)

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,

@ -819,6 +848,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        )

        if self.rocm_aiter_moe_enabled:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
                rocm_aiter_fused_experts)
            return rocm_aiter_fused_experts(
                x,
                layer.w13_weight,

@ -835,8 +866,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                a1_scale=layer.w13_input_scale,
                a2_scale=layer.w2_input_scale,
                block_shape=self.quant_config.weight_block_size)

        if self.use_marlin:
        elif self.use_marlin:
            assert activation == "silu", (
                f"{activation} not supported for Marlin MoE.")
            assert not apply_router_weight_on_input, (

@ -853,11 +883,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                quant_type_id=scalar_types.float8_e4m3fn.id,
                global_num_experts=global_num_experts,
                expert_map=expert_map)

        return fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
        else:
            return self.fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,

@ -872,8 +902,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                    if self.block_quant else layer.w2_weight_scale),
                a1_scale=layer.w13_input_scale,
                a2_scale=layer.w2_input_scale,
                block_shape=self.quant_config.weight_block_size,
                allow_deep_gemm=self.allow_deep_gemm,
            )
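The set_prepare_finalize() hook above defines a small contract: the caller offers a prepare/finalize object and the quantization method either installs a modular kernel or declines. A minimal sketch of a hypothetical caller (the real call site is not shown in this diff):

import vllm.model_executor.layers.fused_moe.modular_kernel as mk


def try_enable_modular_experts(moe_method, dp_size: int, world_size: int,
                               prepare_finalize: mk.FusedMoEPrepareAndFinalize
                               ) -> bool:
    # Returns True if moe_method.fused_experts is now a FusedMoEModularKernel;
    # False means the method keeps its original path (e.g. Marlin, ROCm AITER).
    return moe_method.set_prepare_finalize(dp_size, world_size,
                                           prepare_finalize)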
@ -79,7 +79,6 @@ class DbrxExperts(FusedMoE):
            prefix=prefix,
        )
        self.config = config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.d_model = config.d_model
        self.intermediate_size = (self.config.ffn_config.ffn_hidden_size //
                                  self.tp_size)
@ -31,9 +31,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (get_pp_group,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm

@ -143,7 +141,8 @@ class DeepseekV2MoE(nn.Module):
            intermediate_size=intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            reduce_results=False,
            reduce_results=self.experts.must_reduce_shared_expert_outputs(
            ),
            prefix=f"{prefix}.shared_experts",
        )

@ -154,6 +153,7 @@ class DeepseekV2MoE(nn.Module):
        shared_output = self.shared_experts(hidden_states)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)

        if hidden_states.dtype != torch.float16:
            final_hidden_states = self.experts(
                hidden_states=hidden_states,

@ -171,9 +171,11 @@ class DeepseekV2MoE(nn.Module):
            # See DeepseekV2DecoderLayer for more details.
            final_hidden_states = final_hidden_states + shared_output \
                * (1. / self.routed_scaling_factor)

        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)
            final_hidden_states = (
                self.experts.maybe_all_reduce_tensor_model_parallel(
                    final_hidden_states))

        return final_hidden_states.view(num_tokens, hidden_dim)
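The model changes above and below all route the tensor-parallel reduction decision through two FusedMoE helpers. Their implementations are not shown in this diff; the sketch below is purely illustrative of the kind of decision they encapsulate, on the assumption that an EP combine step already reduces routed-expert outputs across ranks (so the extra all-reduce can be skipped) while shared-expert outputs then need their own reduction.

import torch

from vllm.distributed import tensor_model_parallel_all_reduce


class IllustrativeFusedMoE:
    """Not the vLLM implementation; an assumed model of the two helpers."""

    def __init__(self, tp_size: int, uses_ep_combine: bool):
        self.tp_size = tp_size
        self.uses_ep_combine = uses_ep_combine  # assumed flag, for illustration

    def must_reduce_shared_expert_outputs(self) -> bool:
        # The combine kernel only reduces routed-expert outputs, so shared
        # experts must reduce eagerly when an EP combine is in use.
        return self.uses_ep_combine

    def maybe_all_reduce_tensor_model_parallel(
            self, x: torch.Tensor) -> torch.Tensor:
        # Skip the TP all-reduce when the EP combine already performed it.
        if self.tp_size > 1 and not self.uses_ep_combine:
            return tensor_model_parallel_all_reduce(x)
        return x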
@ -25,8 +25,7 @@ from transformers import Llama4TextConfig
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,

@ -89,7 +88,7 @@ class Llama4MoE(nn.Module):
            quant_config=quant_config,
            bias=False,
            prefix=f"{prefix}.shared_expert",
            reduce_results=False,  # We need to do scatter before reduce
            reduce_results=self.experts.must_reduce_shared_expert_outputs(),
        )

    def forward(self, hidden_states):

@ -102,7 +101,8 @@ class Llama4MoE(nn.Module):
        experts_out = routed_out + shared_out

        if self.tp_size > 1:
            experts_out = tensor_model_parallel_all_reduce(experts_out)
            experts_out = self.experts.maybe_all_reduce_tensor_model_parallel(
                experts_out)

        return experts_out
@ -33,9 +33,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE

@ -129,7 +127,8 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                intermediate_size=config.shared_expert_intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                reduce_results=self.experts.must_reduce_shared_expert_outputs(
                ),
            )
        else:
            self.shared_expert = None

@ -156,7 +155,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output
        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(  # noqa E501
                final_hidden_states)

        return final_hidden_states.view(orig_shape)
@ -30,9 +30,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE

@ -137,7 +135,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
                                           router_logits=router_logits)
        final_hidden_states = final_hidden_states
        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(  # noqa E501
                final_hidden_states)

        return final_hidden_states.view(orig_shape)
@ -158,6 +158,7 @@ class CudaPlatformBase(Platform):
                    "currently not supported with CUDA Graphs.")
                vllm_config.model_config.enforce_eager = True
                compilation_config.use_cudagraph = False
                compilation_config.use_inductor = False

    @classmethod
    def get_current_memory_usage(cls,
@ -865,8 +865,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
        assert output is not None, "Output tensor must be provided."

        if attn_metadata is None:
            # Profiling run.
            return output
            # The zero fill is required when used with DP + EP
            # to ensure all ranks within a DP group compute the
            # same expert outputs.
            return output.fill_(0)

        num_actual_toks = attn_metadata.num_actual_tokens
@ -341,7 +341,8 @@ def init_worker_distributed_environment(
                                 distributed_init_method, local_rank)

    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                      parallel_config.pipeline_parallel_size)
                                      parallel_config.pipeline_parallel_size,
                                      parallel_config.enable_expert_parallel)

    ensure_kv_transfer_initialized(vllm_config)

@ -265,4 +265,5 @@ def init_tpu_worker_distributed_environment(
        backend="gloo",
    )
    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                      parallel_config.pipeline_parallel_size)
                                      parallel_config.pipeline_parallel_size,
                                      parallel_config.enable_expert_parallel)

@ -390,7 +390,8 @@ class CPUWorker(LocalOrDistributedWorkerBase):

        ensure_model_parallel_initialized(
            parallel_config.tensor_parallel_size,
            parallel_config.pipeline_parallel_size)
            parallel_config.pipeline_parallel_size,
            parallel_config.enable_expert_parallel)

    def get_cache_block_size_bytes(self) -> int:
        """Return the size in bytes of a single KV cache block.

@ -416,7 +416,8 @@ def init_worker_distributed_environment(
                                   backend='hccl')

    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                      parallel_config.pipeline_parallel_size)
                                      parallel_config.pipeline_parallel_size,
                                      parallel_config.enable_expert_parallel)

    if torch.distributed.is_initialized():
        torch_world_size = torch.distributed.get_world_size()

@ -442,7 +443,8 @@ def init_worker_distributed_environment(
        torch.distributed.all_reduce(dummy_tensor_hpu)
        assert dummy_tensor_hpu.item() == parallel_config.world_size
    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                      parallel_config.pipeline_parallel_size)
                                      parallel_config.pipeline_parallel_size,
                                      parallel_config.enable_expert_parallel)


def raise_if_cache_size_invalid(num_gpu_blocks, block_size, max_model_len,

@ -76,7 +76,8 @@ class TPUWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
        )
        ensure_model_parallel_initialized(
            self.parallel_config.tensor_parallel_size,
            self.parallel_config.pipeline_parallel_size)
            self.parallel_config.pipeline_parallel_size,
            self.parallel_config.enable_expert_parallel)

        # Device initialization should happen after initializing the distributed
        # runtime.

@ -530,7 +530,8 @@ def init_worker_distributed_environment(
    init_distributed_environment(parallel_config.world_size, rank,
                                 distributed_init_method, local_rank)
    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                      parallel_config.pipeline_parallel_size)
                                      parallel_config.pipeline_parallel_size,
                                      parallel_config.enable_expert_parallel)

    ensure_kv_transfer_initialized(vllm_config)

@ -176,7 +176,8 @@ class XPUWorker(LoRANotSupportedWorkerBase, Worker):

        ensure_model_parallel_initialized(
            parallel_config.tensor_parallel_size,
            parallel_config.pipeline_parallel_size)
            parallel_config.pipeline_parallel_size,
            parallel_config.enable_expert_parallel)
        # global all_reduce needed for overall oneccl warm up
        torch.distributed.all_reduce(torch.zeros(1).xpu())