Modularize fused experts and integrate PPLX kernels (#15956)
This commit is contained in:
parent 418d2f8bfb
commit f9c069c85e
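The new tests added below exercise the commit's central abstraction: a fused-MoE operation assembled from a pluggable prepare/finalize stage (token routing and combine) and an experts stage (the expert GEMMs) via FusedMoEModularKernel. A minimal sketch, condensed from the batched_moe helper in the new tests/kernels/moe/test_pplx_moe.py (single-rank setup; argument values follow that test and are illustrative, not an authoritative API reference):

    from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
        BatchedExperts, BatchedPrepareAndFinalize)
    from vllm.model_executor.layers.fused_moe.modular_kernel import (
        FusedMoEModularKernel)

    def batched_fused_moe(a, w1, w2, topk_weight, topk_ids):
        # Compose a modular fused-MoE kernel from a prepare/finalize stage and
        # an experts stage. A PPLX deployment swaps in PplxPrepareAndFinalize
        # and BatchedTritonExperts, as the pplx tests below do.
        num_experts = w1.shape[0]
        fused_experts = FusedMoEModularKernel(
            BatchedPrepareAndFinalize(a.shape[0], world_size=1, dp_size=1,
                                      rank=0),
            BatchedExperts(max_num_tokens=a.shape[0], dp_size=1, world_size=1))
        return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts)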
@@ -70,6 +70,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
   int64_t num_tokens = input.numel() / input.size(-1);               \
   dim3 grid(num_tokens);                                             \
   dim3 block(std::min(d, 1024));                                     \
+  if (num_tokens == 0) {                                             \
+    return;                                                          \
+  }                                                                  \
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));  \
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();      \
   VLLM_DISPATCH_FLOATING_TYPES(                                      \
@@ -65,5 +65,19 @@
   AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)    \
   AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
 
+#define VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(...) \
+  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)      \
+  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)        \
+  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::UInt16, __VA_ARGS__)     \
+  AT_DISPATCH_CASE(at::ScalarType::UInt32, __VA_ARGS__)     \
+  AT_DISPATCH_CASE(at::ScalarType::UInt64, __VA_ARGS__)
+
 #define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
+
+#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(                                               \
+      TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))
@@ -326,7 +326,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
   }
 
   if (use_global_memory) {
-    VLLM_DISPATCH_INTEGRAL_TYPES(
+    VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
        topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
          // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
          // tensors
@@ -351,7 +351,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
              cumsum_buffer.data_ptr<int32_t>());
        });
   } else if (use_i16) {
-    VLLM_DISPATCH_INTEGRAL_TYPES(
+    VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
        topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
          // set dynamic shared mem
          auto kernel =
@@ -366,7 +366,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
              topk_ids.numel());
        });
   } else {
-    VLLM_DISPATCH_INTEGRAL_TYPES(
+    VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
        topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
          auto kernel =
              vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
@@ -391,7 +391,7 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
   TORCH_CHECK(num_experts == 256,
               "sgl_moe_align_block_size kernel only supports deepseek v3.");
 
-  VLLM_DISPATCH_INTEGRAL_TYPES(
+  VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
      topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] {
        // calc needed amount of shared mem for `cumsum` tensors
        auto options_int =
@@ -108,9 +108,17 @@ __launch_bounds__(TPB) __global__
     }
 }
 
-template <int TPB>
-__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output,
-    int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert)
+template <int TPB, typename IndType>
+__launch_bounds__(TPB) __global__ void moeTopK(
+    const float* inputs_after_softmax,
+    const bool* finished,
+    float* output,
+    IndType* indices,
+    int* source_rows,
+    const int num_experts,
+    const int k,
+    const int start_expert,
+    const int end_expert)
 {
 
     using cub_kvp = cub::KeyValuePair<int, float>;
@@ -182,9 +190,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax
   2) This implementation assumes k is small, but will work for any k.
 */
 
-template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
+template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, typename IndType>
 __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
-    void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices,
+    void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, IndType* indices,
         int* source_rows, const int k, const int start_expert, const int end_expert)
 {
     // We begin by enforcing compile time assertions and setting up compile time constants.
@@ -397,8 +405,8 @@ struct TopkConstants
 };
 } // namespace detail
 
-template <int EXPERTS, int WARPS_PER_TB>
-void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices,
+template <int EXPERTS, int WARPS_PER_TB, typename IndType>
+void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices,
     int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
 {
     static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
@@ -421,10 +429,11 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
         token_expert_indices, num_tokens, topk, 0, num_experts, \
         stream);
 
+template <typename IndType>
 void topkGatingSoftmaxKernelLauncher(
     const float* gating_output,
     float* topk_weights,
-    int* topk_indicies,
+    IndType* topk_indicies,
     int* token_expert_indices,
     float* softmax_workspace,
     const int num_tokens,
@@ -493,14 +502,32 @@ void topk_softmax(
   const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
-  vllm::moe::topkGatingSoftmaxKernelLauncher(
-      gating_output.data_ptr<float>(),
-      topk_weights.data_ptr<float>(),
-      topk_indices.data_ptr<int>(),
-      token_expert_indices.data_ptr<int>(),
-      softmax_workspace.data_ptr<float>(),
-      num_tokens,
-      num_experts,
-      topk,
-      stream);
+
+  if(topk_indices.scalar_type() == at::ScalarType::Int)
+  {
+    vllm::moe::topkGatingSoftmaxKernelLauncher(
+        gating_output.data_ptr<float>(),
+        topk_weights.data_ptr<float>(),
+        topk_indices.data_ptr<int>(),
+        token_expert_indices.data_ptr<int>(),
+        softmax_workspace.data_ptr<float>(),
+        num_tokens,
+        num_experts,
+        topk,
+        stream);
+  }
+  else
+  {
+    assert(topk_indices.scalar_type() == at::ScalarType::UInt32);
+    vllm::moe::topkGatingSoftmaxKernelLauncher(
+        gating_output.data_ptr<float>(),
+        topk_weights.data_ptr<float>(),
+        topk_indices.data_ptr<uint32_t>(),
+        token_expert_indices.data_ptr<int>(),
+        softmax_workspace.data_ptr<float>(),
+        num_tokens,
+        num_experts,
+        topk,
+        stream);
+  }
 }
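The hunks above template moeTopK/topkGatingSoftmax on an index type and make the host-side topk_softmax dispatch on the dtype of the index tensor, so callers may now hand in either int32 or uint32 indices; uint32 is what the PPLX all-to-all path works with (the new pplx tests below cast topk_ids with .to(dtype=torch.uint32)). A minimal caller-side sketch under that reading (shapes are made up; assumes a CUDA device and a torch build with uint32 support):

    import torch

    num_tokens, topk = 16, 2
    # int32 keeps the pre-existing behaviour; any other dtype takes the else
    # branch above, which asserts at::ScalarType::UInt32.
    topk_indices = torch.empty((num_tokens, topk),
                               dtype=torch.uint32,
                               device="cuda")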
@@ -65,11 +65,17 @@ def parse_args():
                         type=int,
                         default=0,
                         help="Master node port")
+    parser.add_argument("--enforce-eager",
+                        action='store_true',
+                        help="Enforce eager mode execution.")
+    parser.add_argument("--trust-remote-code",
+                        action='store_true',
+                        help="Trust remote code.")
     return parser.parse_args()
 
 
 def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
-         dp_master_port, GPUs_per_dp_rank):
+         dp_master_port, GPUs_per_dp_rank, enforce_eager, trust_remote_code):
     os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
     os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
     os.environ["VLLM_DP_SIZE"] = str(dp_size)
@@ -109,10 +115,13 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
                                      max_tokens=[16, 20][global_dp_rank % 2])
 
     # Create an LLM.
-    llm = LLM(model=model,
-              tensor_parallel_size=GPUs_per_dp_rank,
-              enforce_eager=True,
-              enable_expert_parallel=True)
+    llm = LLM(
+        model=model,
+        tensor_parallel_size=GPUs_per_dp_rank,
+        enforce_eager=enforce_eager,
+        enable_expert_parallel=True,
+        trust_remote_code=trust_remote_code,
+    )
     outputs = llm.generate(prompts, sampling_params)
     # Print the outputs.
     for i, output in enumerate(outputs):
@@ -155,7 +164,8 @@ if __name__ == "__main__":
         proc = Process(target=main,
                        args=(args.model, dp_size, local_dp_rank,
                              global_dp_rank, dp_master_ip, dp_master_port,
-                             tp_size))
+                             tp_size, args.enforce_eager,
+                             args.trust_remote_code))
         proc.start()
         procs.append(proc)
     exit_code = 0

tests/kernels/moe/test_batched_moe.py (new file, 114 lines)
@@ -0,0 +1,114 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass

import pytest
import torch
import triton.language as tl

from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    invoke_moe_batched_triton_kernel)


@dataclass
class BatchedMMConfig:
    dtype: torch.dtype
    num_experts: int
    max_tokens_per_expert: int
    K: int
    N: int


@dataclass
class BatchedMMTensors:
    A: torch.Tensor  # [E, max_tokens, K]
    B: torch.Tensor  # [E, K, N] - column major
    C: torch.Tensor  # [E, max_tokens, N]
    num_expert_tokens: torch.Tensor  # [E]

    @staticmethod
    def make_tensors(config: BatchedMMConfig):
        A = torch.randn(
            (config.num_experts, config.max_tokens_per_expert, config.K),
            device="cuda",
            dtype=config.dtype) / 10
        B = torch.randn((config.num_experts, config.N, config.K),
                        device="cuda",
                        dtype=config.dtype)
        C = torch.zeros(
            (config.num_experts, config.max_tokens_per_expert, config.N),
            device="cuda",
            dtype=config.dtype)
        num_expert_tokens = torch.randint(low=0,
                                          high=config.max_tokens_per_expert,
                                          size=(config.num_experts, ),
                                          device="cuda",
                                          dtype=torch.int32)
        return BatchedMMTensors(A, B, C, num_expert_tokens)


def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
             num_expert_tokens: torch.Tensor) -> torch.Tensor:

    num_expert_tokens_cpu = num_expert_tokens.clone()
    num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
    num_experts = num_expert_tokens.size(0)

    for e in range(num_experts):
        num_tokens = num_expert_tokens_cpu[e]
        C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)

    return C


@pytest.mark.parametrize("num_experts", [16, 32])
@pytest.mark.parametrize("max_tokens_per_expert",
                         [32, 64, 128, 192, 224, 256, 512])
@pytest.mark.parametrize("K", [128, 256, 1024])
@pytest.mark.parametrize("N", [128, 256, 512, 1024])
@pytest.mark.parametrize("dtype",
                         [torch.float32, torch.float16, torch.bfloat16])
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
                    N: int, dtype: torch.dtype):

    config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N)
    tensors = BatchedMMTensors.make_tensors(config)

    test_output = tensors.C
    ref_output = test_output.clone()

    compute_tl_dtype = {
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
        torch.float32: tl.float32
    }[test_output.dtype]
    invoke_moe_batched_triton_kernel(
        tensors.A,
        tensors.B,
        test_output,
        tensors.num_expert_tokens,
        compute_tl_dtype,
        # Quantization data
        None,
        None,
        None,
        # Quantization schemes
        False,
        False,
        False,
        config={
            "BLOCK_SIZE_M": 16,
            "BLOCK_SIZE_N": 16,
            "BLOCK_SIZE_K": 16
        })

    ref_output = ref_impl(tensors.A, tensors.B, ref_output,
                          tensors.num_expert_tokens)

    rtol, atol = {
        torch.float16: (6e-2, 6e-2),
        torch.bfloat16: (6e-2, 6e-2),
        torch.float32: (1e-2, 1e-2),
    }[test_output.dtype]

    torch.testing.assert_close(test_output, ref_output, atol=atol, rtol=rtol)
@@ -30,6 +30,11 @@ MNK_FACTORS = [
     (224, 3072, 1536),
 ]
 
+vllm_config = VllmConfig(parallel_config=ParallelConfig(
+    pipeline_parallel_size=1))
+vllm_config.scheduler_config.max_num_seqs = 128
+vllm_config.scheduler_config.max_model_len = 8192
+
 
 @dataclasses.dataclass
 class MOETensors:
@@ -190,7 +195,7 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
         'w1_q': moe_tensors.w1_q.transpose(1, 2),  # type: ignore[union-attr]
         'w2_q': moe_tensors.w2_q.transpose(1, 2),  # type: ignore[union-attr]
         'topk_weights': topk_weights,
-        'topk_ids_': topk_ids,
+        'topk_ids': topk_ids,
         'ab_strides1': moe_tensors.ab_strides1,
         'c_strides1': moe_tensors.c_strides1,
         'ab_strides2': moe_tensors.ab_strides2,
@@ -231,18 +236,15 @@ def test_cutlass_moe_8_bit_no_graph(
     per_out_ch: bool,
 ):
     current_platform.seed_everything(7)
-    with set_current_vllm_config(
-            VllmConfig(parallel_config=ParallelConfig(
-                pipeline_parallel_size=1))):
+    with set_current_vllm_config(vllm_config):
         mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
                                                   per_out_ch)
 
         score = torch.randn((m, e), device="cuda", dtype=torch.half)
-        topk_weights, topk_ids = fused_topk(mt.a,
+        topk_weights, topk_ids, _ = fused_topk(mt.a,
                                             score,
                                             topk,
                                             renormalize=False)
 
         # Note that we are using the dequantized versions of the tensors.
         # Using a, w1 and w2 directly results in minor output differences.
@@ -276,20 +278,17 @@ def test_cutlass_moe_8_bit_cuda_graph(
     per_out_ch: bool,
 ):
     current_platform.seed_everything(7)
-    with set_current_vllm_config(
-            VllmConfig(parallel_config=ParallelConfig(
-                pipeline_parallel_size=1))):
+    with set_current_vllm_config(vllm_config):
         dtype = torch.half
 
         mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
                                                   per_out_ch)
 
         score = torch.randn((m, e), device="cuda", dtype=dtype)
-        topk_weights, topk_ids = fused_topk(mt.a,
+        topk_weights, topk_ids, _ = fused_topk(mt.a,
                                             score,
                                             topk,
                                             renormalize=False)
 
         # Note that we are using the dequantized versions of the tensors.
         # Using a, w1 and w2 directly results in minor output differences.
@@ -334,18 +333,15 @@ def test_cutlass_moe_8_bit_EP(
     ep_size: int,
 ):
     current_platform.seed_everything(7)
-    with set_current_vllm_config(
-            VllmConfig(parallel_config=ParallelConfig(
-                pipeline_parallel_size=1))):
+    with set_current_vllm_config(vllm_config):
         mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
                                                   per_out_channel)
 
         score = torch.randn((m, e), device="cuda", dtype=torch.half)
-        topk_weights, topk_ids = fused_topk(mt.a,
+        topk_weights, topk_ids, _ = fused_topk(mt.a,
                                             score,
                                             topk,
                                             renormalize=False)
 
         # Note that we are using the dequantized versions of the tensors.
         # Using a, w1 and w2 directly results in minor output differences.
@@ -12,6 +12,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
 
 import vllm.model_executor.layers.fused_moe  # noqa
 from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
+from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
 from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
@@ -32,6 +33,10 @@ NUM_EXPERTS = [8, 64]
 EP_SIZE = [1, 4]
 TOP_KS = [2, 6]
 
+vllm_config = VllmConfig()
+vllm_config.scheduler_config.max_num_seqs = 128
+vllm_config.scheduler_config.max_model_len = 8192
+
 
 @pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128])
 @pytest.mark.parametrize("n", [128, 1024, 2048])
@@ -70,31 +75,33 @@ def test_fused_moe(
     else:
         e_map = None
 
-    torch_output = torch_moe(a, w1, w2, score, topk, e_map)
-    iterative_output = iterative_moe(a,
-                                     w1,
-                                     w2,
-                                     score,
-                                     topk,
-                                     global_num_experts=e,
-                                     expert_map=e_map,
-                                     renormalize=False)
-
-    # Pad the weight if moe padding is enabled
-    if padding:
-        w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128]
-        torch.cuda.empty_cache()
-        w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128]
-        torch.cuda.empty_cache()
-
-    triton_output = fused_moe(a,
-                              w1,
-                              w2,
-                              score,
-                              topk,
-                              global_num_experts=e,
-                              expert_map=e_map,
-                              renormalize=False)
+    with set_current_vllm_config(vllm_config):
+        torch_output = torch_moe(a, w1, w2, score, topk, e_map)
+        iterative_output = iterative_moe(a,
+                                         w1,
+                                         w2,
+                                         score,
+                                         topk,
+                                         global_num_experts=e,
+                                         expert_map=e_map,
+                                         renormalize=False)
+
+        # Pad the weight if moe padding is enabled
+        if padding:
+            w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128]
+            torch.cuda.empty_cache()
+            w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128]
+            torch.cuda.empty_cache()
+
+        triton_output = fused_moe(a,
+                                  w1,
+                                  w2,
+                                  score,
+                                  topk,
+                                  global_num_experts=e,
+                                  expert_map=e_map,
+                                  renormalize=False)
 
     torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
     torch.testing.assert_close(iterative_output,
                                torch_output,
@@ -115,7 +122,6 @@ def test_fused_moe(
 def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
                         ep_size: int, dtype: torch.dtype, group_size: int,
                         has_zp: bool, weight_bits: int):
-    print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits)
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
     w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
@@ -194,22 +200,24 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
     else:
         e_map = None
 
-    triton_output = fused_moe(a,
-                              w1_qweight,
-                              w2_qweight,
-                              score,
-                              topk,
-                              renormalize=False,
-                              use_int4_w4a16=weight_bits == 4,
-                              use_int8_w8a16=weight_bits == 8,
-                              global_num_experts=e,
-                              expert_map=e_map,
-                              w1_scale=w1_scales,
-                              w2_scale=w2_scales,
-                              w1_zp=w1_qzeros if has_zp else None,
-                              w2_zp=w2_qzeros if has_zp else None,
-                              block_shape=[0, group_size])
-    torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map)
+    with set_current_vllm_config(vllm_config):
+        triton_output = fused_moe(a,
+                                  w1_qweight,
+                                  w2_qweight,
+                                  score,
+                                  topk,
+                                  renormalize=False,
+                                  use_int4_w4a16=weight_bits == 4,
+                                  use_int8_w8a16=weight_bits == 8,
+                                  global_num_experts=e,
+                                  expert_map=e_map,
+                                  w1_scale=w1_scales,
+                                  w2_scale=w2_scales,
+                                  w1_zp=w1_qzeros if has_zp else None,
+                                  w2_zp=w2_qzeros if has_zp else None,
+                                  block_shape=[0, group_size])
+        torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map)
 
     torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
@@ -515,7 +523,8 @@ def test_fused_marlin_moe(
 
     topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
 
-    torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map)
+    with set_current_vllm_config(vllm_config):
+        torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map)
 
     marlin_output = torch.ops.vllm.fused_marlin_moe(
         a,

tests/kernels/moe/test_pplx_moe.py (new file, 691 lines)
@@ -0,0 +1,691 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for the MOE layers.

Run `pytest tests/kernels/test_pplx_moe.py`.
"""
import dataclasses
import os
import traceback
from typing import Callable, Optional

import pytest
import torch

try:
    from pplx_kernels import AllToAll
    from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
                                      nvshmem_finalize, nvshmem_get_unique_id,
                                      nvshmem_init)
    has_pplx = True
except ImportError:
    has_pplx = False

from torch.multiprocessing import (
    spawn)  # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec

from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import override_config
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    BatchedExperts, BatchedPrepareAndFinalize, BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk,
                                                            get_default_config)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel)
from vllm.platforms import current_platform

PPLX_PREPARE_COMBOS = [(4, 128, 128), (32, 1024, 512), (64, 1024, 512),
                       (222, 2048, 1024)]

PPLX_MOE_COMBOS = [
    (1, 128, 128),
    (2, 128, 512),
    (3, 1024, 2048),
    (32, 128, 1024),
    (45, 512, 2048),
    (64, 1024, 1024),
    (222, 1024, 2048),
]

NUM_EXPERTS = [8, 64]
EP_SIZE = [1, 4]
TOP_KS = [1, 2, 6]

vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192

P = ParamSpec("P")

requires_pplx = pytest.mark.skipif(
    not has_pplx,
    reason="Requires PPLX kernels",
)


@dataclasses.dataclass
class ProcessGroupInfo:
    world_size: int
    world_local_size: int
    rank: int
    node_rank: int
    local_rank: int
    device: torch.device


def _worker_parallel_launch(
    local_rank: int,
    world_size: int,
    world_local_size: int,
    node_rank: int,
    init_method: str,
    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    rank = node_rank * world_local_size + local_rank
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    torch.distributed.init_process_group(
        backend="cpu:gloo,cuda:nccl",
        init_method=init_method,
        rank=rank,
        world_size=world_size,
        device_id=device,
    )
    barrier = torch.tensor([rank], device=device)
    torch.distributed.all_reduce(barrier)

    try:
        worker(
            ProcessGroupInfo(
                world_size=world_size,
                world_local_size=world_local_size,
                rank=rank,
                node_rank=node_rank,
                local_rank=local_rank,
                device=device,
            ),
            *args,
            **kwargs,
        )
    except Exception as ex:
        print(ex)
        traceback.print_exc()
        raise
    finally:
        torch.distributed.destroy_process_group()


def parallel_launch(
    world_size: int,
    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    assert not kwargs
    spawn(
        _worker_parallel_launch,
        args=(
            world_size,
            world_size,
            0,
            "tcp://localhost:29500",
            worker,
        ) + args,
        nprocs=world_size,
        join=True,
    )


def parallel_launch_from_env(
    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    """
    Launches a worker function in parallel across all processes in the current
    environment. The environment must have the following variables set:
    - WORLD_SIZE: The total number of processes.
    - WORLD_LOCAL_SIZE: The number of processes on the current node.
    - NODE_RANK: The rank of the current
    - MASTER_ADDR: The address of the master process.
    - MASTER_PORT: The port of the master process.
    """
    assert not kwargs
    world_size = int(os.environ["WORLD_SIZE"])
    world_local_size = int(os.environ["WORLD_LOCAL_SIZE"])
    node_rank = int(os.environ["NODE_RANK"])
    assert "MASTER_ADDR" in os.environ
    assert "MASTER_PORT" in os.environ
    spawn(
        _worker_parallel_launch,
        args=(
            world_size,
            world_local_size,
            node_rank,
            "env://",
            worker,
        ) + args,
        nprocs=world_local_size,
        join=True,
    )


def torch_prepare(
    a: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int,
    max_num_tokens: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    assert topk_ids.dim() == 2
    assert topk_ids.shape[0] == a.shape[0]

    num_tokens, hidden_dim = a.shape
    topk = topk_ids.shape[1]

    tokens_per_expert = torch.bincount(topk_ids.view(-1),
                                       minlength=num_experts)

    assert tokens_per_expert.numel() == num_experts

    if max_num_tokens is None:
        max_num_tokens = int(tokens_per_expert.max().item())

    b_a = torch.zeros((num_experts, max_num_tokens, hidden_dim),
                      dtype=a.dtype,
                      device=a.device)

    token_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device)

    for token in range(num_tokens):
        for j in range(topk):
            expert_id = topk_ids[token, j]
            idx = token_counts[expert_id]
            b_a[expert_id, idx:idx + 1, :] = a[token, :]
            token_counts[expert_id] = token_counts[expert_id] + 1

    return b_a, tokens_per_expert


def torch_finalize(b_out: torch.Tensor, topk_weight: torch.Tensor,
                   topk_ids: torch.Tensor) -> torch.Tensor:
    num_tokens = topk_ids.shape[0]
    num_experts = b_out.shape[0]
    K = b_out.shape[-1]
    out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device)
    expert_counts = torch.zeros(num_experts,
                                dtype=torch.int,
                                device=b_out.device)
    for token in range(num_tokens):
        expert_ids = topk_ids[token]
        for i in range(expert_ids.numel()):
            expert_id = expert_ids[i]
            idx = expert_counts[expert_id]
            out[token, :] = out[token, :] + b_out[expert_id, idx:idx +
                                                  1, :] * topk_weight[token, i]
            expert_counts[expert_id] = expert_counts[expert_id] + 1

    return out


def torch_batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
) -> torch.Tensor:
    num_experts = w1.shape[0]
    b_a, tokens_per_expert = torch_prepare(a, topk_ids, num_experts)
    assert b_a.dim() == 3
    num_tokens, topk = topk_ids.shape
    _, max_num_tokens, K = b_a.shape
    assert num_experts == b_a.shape[0] and w2.shape[1] == K
    out = torch.zeros((num_experts, max_num_tokens, K),
                      dtype=b_a.dtype,
                      device=b_a.device)
    tmp = torch.empty((max_num_tokens, w1.shape[1] // 2),
                      dtype=b_a.dtype,
                      device=b_a.device)
    for expert in range(num_experts):
        num = tokens_per_expert[expert]
        if num > 0:
            torch.ops._C.silu_and_mul(
                tmp[:num], b_a[expert, :num, :] @ w1[expert].transpose(0, 1))
            out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1)

    return torch_finalize(out, topk_weight, topk_ids)


def batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
) -> torch.Tensor:
    num_experts = w1.shape[0]

    fused_experts = FusedMoEModularKernel(
        BatchedPrepareAndFinalize(a.shape[0], world_size=1, dp_size=1, rank=0),
        BatchedExperts(max_num_tokens=a.shape[0], dp_size=1, world_size=1))

    return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts)


# Note: same as torch_moe but with fused_topk factored out.
def torch_moe2(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
) -> torch.Tensor:
    M, K = a.shape
    topk = topk_ids.shape[1]
    a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
    out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    num_experts = w1.shape[0]
    for i in range(num_experts):
        mask = (topk_ids == i).view(-1)
        if mask.sum():
            out[mask] = SiluAndMul()(
                a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)

    return (out.view(M, -1, w2.shape[1]) *
            topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)


@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 512, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_fused_moe_batched_experts(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
):
    current_platform.seed_everything(7)

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    score = torch.randn((m, e), device="cuda", dtype=dtype)

    with set_current_vllm_config(vllm_config):
        topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
        baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
        torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids)
        batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids)

    torch.testing.assert_close(baseline_output,
                               torch_output,
                               atol=2e-2,
                               rtol=0)
    torch.testing.assert_close(baseline_output,
                               batched_output,
                               atol=2e-2,
                               rtol=0)


def rank_chunk(num: int, r: int, w: int) -> int:
    rem = num % w
    return (num // w) + (1 if r < rem else 0)


def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor:
    chunk = rank_chunk(t.shape[0], r, w)
    return t[(r * chunk):(r + 1) * chunk]


def pplx_prepare_finalize(pgi: ProcessGroupInfo, dp_size: int, a: torch.Tensor,
                          topk_weight: torch.Tensor, topk_ids: torch.Tensor,
                          num_experts: int) -> torch.Tensor:
    from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
        PplxPrepareAndFinalize)

    assert torch.cuda.current_device() == pgi.local_rank

    topk = topk_ids.shape[1]
    num_tokens, hidden_dim = a.shape
    block_size = 128
    device = pgi.device
    rank = pgi.rank
    world_size = pgi.world_size
    max_num_tokens = rank_chunk(num_tokens, 0, world_size)

    ata = AllToAll.internode(
        max_num_tokens=max_num_tokens,
        num_experts=num_experts,
        experts_per_token=topk,
        rank=rank,
        world_size=world_size,
        dp_size=dp_size,
        hidden_dim=hidden_dim,
        hidden_dim_bytes=hidden_dim * a.dtype.itemsize,
        hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else
                                ((hidden_dim + block_size - 1) // block_size *
                                 torch.float32.itemsize)),
    )

    topk_ids = topk_ids.to(dtype=torch.uint32)

    prepare_finalize = PplxPrepareAndFinalize(
        ata,
        max_num_tokens,
        world_size,
        rank,
        dp_size,
        a.dtype,
    )

    a_chunk = chunk_by_rank(a, rank, world_size).to(device)
    chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
    chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)

    b_a, b_a_scale, expert_num_tokens = prepare_finalize.prepare(
        a_chunk,
        None,
        None,
        chunk_topk_weight,
        chunk_topk_ids,
        num_experts,
        None,
        False,
    )

    b_a = b_a * 1.5

    out = torch.full(
        (max_num_tokens, hidden_dim),
        torch.nan,
        dtype=a.dtype,
        device=device,
    )

    prepare_finalize.finalize(
        out,
        b_a,
        chunk_topk_weight,
        chunk_topk_ids,
        False,
    )

    torch.cuda.synchronize()

    ata.destroy()

    num_tokens = a_chunk.shape[0]

    return out[:num_tokens]


def _pplx_prepare_finalize(
    pgi: ProcessGroupInfo,
    dp_size: int,
    a: torch.Tensor,
    score: torch.Tensor,
    topk: torch.Tensor,
    num_experts: int,
):
    uid = nvshmem_get_unique_id(
    ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
    torch.distributed.broadcast(uid, src=0)
    nvshmem_init(uid, pgi.rank, pgi.world_size)
    device = pgi.device

    topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
    k = a.shape[1]

    a_rep = torch.repeat_interleave(a, topk, dim=0).to(device)

    torch_output = (a_rep.view(-1, topk, k) * 1.5 *
                    topk_weight.view(-1, topk, 1).to(device)).sum(dim=1).to(
                        a.dtype)

    pplx_output = pplx_prepare_finalize(pgi, dp_size, a, topk_weight, topk_ids,
                                        num_experts)

    torch_output = chunk_by_rank(torch_output, pgi.rank,
                                 pgi.world_size).to(pplx_output.device)

    torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0)

    nvshmem_finalize()


# TODO (bnell): this test point does not work for odd M due to how the test is
# written, not due to limitations of the pplx kernels. The pplx_moe
# test below is able to deal with odd M.
@pytest.mark.parametrize("mnk", PPLX_PREPARE_COMBOS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@requires_pplx
def test_pplx_prepare_finalize(
    mnk: tuple[int, int, int],
    e: int,
    topk: int,
    dtype: torch.dtype,
    world_dp_size: tuple[int, int],
):
    current_platform.seed_everything(7)
    m, n, k = mnk
    world_size, dp_size = world_dp_size
    device = "cuda"
    a = torch.randn((m, k), device=device, dtype=dtype) / 10
    score = torch.randn((m, e), device=device, dtype=dtype)

    parallel_launch(world_size, _pplx_prepare_finalize, dp_size, a, score,
                    topk, e)


def pplx_moe(
    rank: int,
    world_size: int,
    dp_size: int,
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    use_compile: bool = True,
    use_cudagraphs: bool = True,
) -> torch.Tensor:
    from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
        PplxPrepareAndFinalize)

    device = torch.device("cuda", rank)
    hidden_dim = a.shape[1]
    num_experts = w1.shape[0]
    block_size = 128
    topk = topk_ids.shape[1]
    max_num_tokens = rank_chunk(a.shape[0], 0, world_size)

    ata = AllToAll.internode(
        max_num_tokens=max_num_tokens,
        num_experts=num_experts,
        experts_per_token=topk,
        rank=rank,
        world_size=world_size,
        dp_size=dp_size,
        hidden_dim=hidden_dim,
        hidden_dim_bytes=hidden_dim * a.dtype.itemsize,
        hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else
                                ((hidden_dim + block_size - 1) // block_size *
                                 torch.float32.itemsize)),
    )

    topk_ids = topk_ids.to(dtype=torch.uint32)

    prepare_finalize = PplxPrepareAndFinalize(
        ata,
        max_num_tokens,
        world_size,
        rank,
        dp_size,
    )

    experts = BatchedTritonExperts(max_num_tokens=a.shape[0],
                                   world_size=world_size,
                                   dp_size=dp_size)

    fused_experts = FusedMoEModularKernel(
        prepare_finalize,
        experts,
    )

    # Note: workers with the same dp_rank must use the exact same inputs.
    a_chunk = chunk_by_rank(a, rank, world_size).to(device)
    chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
    chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)

    # Chunking weights like this only works for batched format
    w1_chunk = chunk_by_rank(w1, rank, world_size).to(device)
    w2_chunk = chunk_by_rank(w2, rank, world_size).to(device)

    if use_compile:
        _fused_experts = torch.compile(fused_experts,
                                       backend='inductor',
                                       fullgraph=True)
    else:
        _fused_experts = fused_experts

    out = _fused_experts(a_chunk,
                         w1_chunk,
                         w2_chunk,
                         chunk_topk_weight,
                         chunk_topk_ids,
                         global_num_experts=num_experts)

    if use_cudagraphs:
        out.fill_(0)
        stream = torch.cuda.Stream()
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, stream=stream):
            out = _fused_experts(a_chunk,
                                 w1_chunk,
                                 w2_chunk,
                                 chunk_topk_weight,
                                 chunk_topk_ids,
                                 global_num_experts=num_experts)

        torch.cuda.synchronize()
        graph.replay()

    torch.cuda.synchronize()

    ata.destroy()

    return out


def _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
    assert torch.cuda.current_device() == pgi.local_rank

    num_experts = w1.shape[0]
    device = pgi.device
    rank = pgi.rank
    world_size = pgi.world_size
    max_num_tokens = rank_chunk(a.shape[0], 0, world_size)

    prepare_finalize = BatchedPrepareAndFinalize(
        max_num_tokens=max_num_tokens,
        world_size=world_size,
        dp_size=dp_size,
        rank=rank,
    )

    experts = BatchedExperts(max_num_tokens=a.shape[0],
                             world_size=1,
                             dp_size=1)

    fused_experts = FusedMoEModularKernel(
        prepare_finalize,
        experts,
    )

    # Note: workers with the same dp_rank must use the exact same inputs.
    a_chunk = chunk_by_rank(a, rank, world_size).to(device)
    chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
    chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)

    out = fused_experts(
        a_chunk,
        # Chunking weights like this only works for batched format
        chunk_by_rank(w1, rank, world_size).to(device),
        chunk_by_rank(w2, rank, world_size).to(device),
        chunk_topk_weight,
        chunk_topk_ids,
        global_num_experts=num_experts)

    return out


def _pplx_moe(
    pgi: ProcessGroupInfo,
    dp_size: int,
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    score: torch.Tensor,
    topk: int,
):
    uid = nvshmem_get_unique_id(
    ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
    torch.distributed.broadcast(uid, src=0)
    nvshmem_init(uid, pgi.rank, pgi.world_size)

    m, k = a.shape
    e, _, n = w2.shape

    moe_config = get_default_config(m, e, n, k, topk, a.dtype, False)

    with set_current_vllm_config(vllm_config), override_config(moe_config):
        topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
        torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
        pplx_output = pplx_moe(pgi.rank, pgi.world_size, dp_size, a, w1, w2,
                               topk_weight, topk_ids)
        # TODO (bnell): fix + re-enable
        #batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight,
        #                              topk_ids)

    torch_output = chunk_by_rank(torch_output, pgi.rank,
                                 pgi.world_size).to(pplx_output.device)

    torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0)
    #torch.testing.assert_close(batched_output, torch_output, atol=2e-2, rtol=0)

    nvshmem_finalize()


@pytest.mark.parametrize("mnk", PPLX_MOE_COMBOS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@requires_pplx
def test_pplx_moe(
    mnk: tuple[int, int, int],
    e: int,
    topk: int,
    dtype: torch.dtype,
    world_dp_size: tuple[int, int],
):
    current_platform.seed_everything(7)
    m, n, k = mnk
    world_size, dp_size = world_dp_size
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    score = torch.randn((m, e), device="cuda", dtype=dtype)

    parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk)
@@ -7,6 +7,7 @@ import pytest
 import torch
 
 from vllm import _custom_ops as ops
+from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.platforms import current_platform
@@ -15,6 +16,10 @@ if current_platform.get_device_capability() < (9, 0):
     pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
                 allow_module_level=True)
 
+vllm_config = VllmConfig()
+vllm_config.scheduler_config.max_num_seqs = 128
+vllm_config.scheduler_config.max_model_len = 8192
+
 
 def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
     """Matrix multiplication function that supports per-token input
@@ -137,20 +142,21 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
     w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale
     score = torch.randn((M, E), dtype=dtype)
 
-    ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk)
-    out = fused_moe(
-        a,
-        w1,
-        w2,
-        score,
-        topk,
-        renormalize=False,
-        use_fp8_w8a8=True,  # using fp8
-        per_channel_quant=True,
-        w1_scale=w1_s,
-        w2_scale=w2_s,
-        block_shape=None,  # Not using block quantization
-    )
+    with set_current_vllm_config(vllm_config):
+        ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk)
+        out = fused_moe(
+            a,
+            w1,
+            w2,
+            score,
+            topk,
+            renormalize=False,
+            use_fp8_w8a8=True,  # using fp8
+            per_channel_quant=True,
+            w1_scale=w1_s,
+            w2_scale=w2_s,
+            block_shape=None,  # Not using block quantization
+        )
 
     # Check results
     rel_diff = (torch.mean(
@@ -11,7 +11,7 @@ from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
-    deep_gemm_moe_fp8)
+    _valid_deep_gemm_shape, deep_gemm_moe_fp8)
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
 from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
     moe_align_block_size)
@@ -30,6 +30,10 @@ if current_platform.get_device_capability() < (9, 0):
     pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
                 allow_module_level=True)
 
+vllm_config = VllmConfig()
+vllm_config.scheduler_config.max_num_seqs = 128
+vllm_config.scheduler_config.max_model_len = 8192
+
 # Test configurations
 DTYPES = [torch.bfloat16]  # [torch.half, torch.bfloat16, torch.float32]
 NUM_TOKENS = [7, 83, 2048]
@ -210,7 +214,6 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
|
|||||||
score = torch.randn((M, E), dtype=dtype)
|
score = torch.randn((M, E), dtype=dtype)
|
||||||
|
|
||||||
# Set the context to avoid lots of warning spam.
|
# Set the context to avoid lots of warning spam.
|
||||||
vllm_config = VllmConfig()
|
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
out = fused_moe(
|
out = fused_moe(
|
||||||
a,
|
a,
|
||||||
@ -258,6 +261,7 @@ def per_block_cast_to_fp8(
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"M,N,K,block_size,out_dtype,seed",
|
"M,N,K,block_size,out_dtype,seed",
|
||||||
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
|
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
|
||||||
|
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
|
def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
|
||||||
# only aligned sizes
|
# only aligned sizes
|
||||||
@ -381,15 +385,11 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
|
|||||||
block_size = [block_m, block_m]
|
block_size = [block_m, block_m]
|
||||||
dtype = torch.bfloat16
|
dtype = torch.bfloat16
|
||||||
|
|
||||||
# only aligned sizes
|
if topk > E:
|
||||||
if (N % block_m != 0 or K % block_m != 0 or topk > E):
|
pytest.skip(f"Skipping test: topk={topk} > E={E}")
|
||||||
pytest.skip(
|
|
||||||
f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}")
|
|
||||||
|
|
||||||
if N <= 512:
|
if not _valid_deep_gemm_shape(M, N, K):
|
||||||
pytest.skip("Skipping N <= 512 until performance issues solved.")
|
pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}")
|
||||||
|
|
||||||
vllm_config = VllmConfig()
|
|
||||||
|
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
||||||
|
|||||||
@ -18,6 +18,10 @@ if current_platform.get_device_capability() < (7, 0):
|
|||||||
pytest.skip("INT8 Triton requires CUDA 7.0 or higher",
|
pytest.skip("INT8 Triton requires CUDA 7.0 or higher",
|
||||||
allow_module_level=True)
|
allow_module_level=True)
|
||||||
|
|
||||||
|
vllm_config = VllmConfig()
|
||||||
|
vllm_config.scheduler_config.max_num_seqs = 128
|
||||||
|
vllm_config.scheduler_config.max_model_len = 8192
|
||||||
|
|
||||||
|
|
||||||
# For test
|
# For test
|
||||||
def native_per_token_group_quant_int8(x,
|
def native_per_token_group_quant_int8(x,
|
||||||
@ -174,7 +178,6 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
|
|||||||
score = torch.randn((M, E), dtype=dtype)
|
score = torch.randn((M, E), dtype=dtype)
|
||||||
|
|
||||||
# Set the context to avoid lots of warning spam.
|
# Set the context to avoid lots of warning spam.
|
||||||
vllm_config = VllmConfig()
|
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
out = fused_moe(
|
out = fused_moe(
|
||||||
a,
|
a,
|
||||||
|
|||||||
@ -23,6 +23,7 @@ If you only need to use the distributed environment without model/pipeline
|
|||||||
"""
|
"""
|
||||||
import contextlib
|
import contextlib
|
||||||
import gc
|
import gc
|
||||||
|
import importlib.util
|
||||||
import pickle
|
import pickle
|
||||||
import weakref
|
import weakref
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
@ -42,7 +43,7 @@ from vllm.distributed.device_communicators.base_device_communicator import (
|
|||||||
from vllm.distributed.utils import StatelessProcessGroup
|
from vllm.distributed.utils import StatelessProcessGroup
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname,
|
from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname,
|
||||||
supports_custom_op)
|
run_once, supports_custom_op)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -936,9 +937,49 @@ def init_distributed_environment(
|
|||||||
"world group already initialized with a different world size")
|
"world group already initialized with a different world size")
|
||||||
|
|
||||||
|
|
||||||
|
PPLX_DID_INIT: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
@run_once
|
||||||
|
def pplx_init(rank, world_size):
|
||||||
|
has_pplx = importlib.util.find_spec("pplx_kernels") is not None
|
||||||
|
|
||||||
|
if has_pplx and world_size > 1:
|
||||||
|
from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
|
||||||
|
nvshmem_get_unique_id, nvshmem_init)
|
||||||
|
try:
|
||||||
|
global PPLX_DID_INIT
|
||||||
|
logger.debug(
|
||||||
|
"Initialize NVSHMEM for PPLX kernels: rank=%d, "
|
||||||
|
"world size=%d", rank, world_size)
|
||||||
|
uid = nvshmem_get_unique_id(
|
||||||
|
) if rank == 0 else nvshmem_alloc_empty_unique_id()
|
||||||
|
uid_gpu = uid.cuda()
|
||||||
|
get_world_group().broadcast(uid_gpu, src=0)
|
||||||
|
uid = uid_gpu.to(device='cpu')
|
||||||
|
logger.debug("PPLX NVSHMEM UID = %s", uid)
|
||||||
|
nvshmem_init(uid, rank, world_size)
|
||||||
|
PPLX_DID_INIT = True
|
||||||
|
except Exception as ex:
|
||||||
|
logger.error("Failed to initialize NVSHMEM for PPLX: %s", ex)
|
||||||
|
|
||||||
|
|
||||||
|
@run_once
|
||||||
|
def pplx_finalize():
|
||||||
|
global PPLX_DID_INIT
|
||||||
|
if PPLX_DID_INIT:
|
||||||
|
from pplx_kernels.nvshmem import nvshmem_finalize
|
||||||
|
logger.debug("PPLX NVSHMEM finalize")
|
||||||
|
from vllm.model_executor.layers.fused_moe.layer import (
|
||||||
|
_all_to_all_cache)
|
||||||
|
_all_to_all_cache.destroy()
|
||||||
|
nvshmem_finalize()
|
||||||
|
|
||||||
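pplx_init and pplx_finalize above are decorated with run_once so that repeated distributed (re)initialization or teardown cannot set up or finalize NVSHMEM twice. A minimal sketch of what such a decorator could look like (the actual vllm.utils.run_once may differ in details) is:

import functools

def run_once(fn):
    # Call fn at most once; later calls become no-ops returning None.
    has_run = False

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        nonlocal has_run
        if has_run:
            return None
        has_run = True
        return fn(*args, **kwargs)

    return wrapper

@run_once
def _setup() -> str:
    print("setting up")
    return "ready"

assert _setup() == "ready"   # first call runs
assert _setup() is None      # second call is skipped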
|
|
||||||
def initialize_model_parallel(
|
def initialize_model_parallel(
|
||||||
tensor_model_parallel_size: int = 1,
|
tensor_model_parallel_size: int = 1,
|
||||||
pipeline_model_parallel_size: int = 1,
|
pipeline_model_parallel_size: int = 1,
|
||||||
|
enable_expert_parallel: bool = False,
|
||||||
backend: Optional[str] = None,
|
backend: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
@ -1041,10 +1082,14 @@ def initialize_model_parallel(
|
|||||||
_DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group,
|
_DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group,
|
||||||
_EP.rank_in_group)
|
_EP.rank_in_group)
|
||||||
|
|
||||||
|
if enable_expert_parallel:
|
||||||
|
pplx_init(rank, world_size)
|
||||||
|
|
||||||
|
|
||||||
def ensure_model_parallel_initialized(
|
def ensure_model_parallel_initialized(
|
||||||
tensor_model_parallel_size: int,
|
tensor_model_parallel_size: int,
|
||||||
pipeline_model_parallel_size: int,
|
pipeline_model_parallel_size: int,
|
||||||
|
enable_expert_parallel: bool = False,
|
||||||
backend: Optional[str] = None,
|
backend: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Helper to initialize model parallel groups if they are not initialized,
|
"""Helper to initialize model parallel groups if they are not initialized,
|
||||||
@ -1055,7 +1100,8 @@ def ensure_model_parallel_initialized(
|
|||||||
get_world_group().device_group)
|
get_world_group().device_group)
|
||||||
if not model_parallel_is_initialized():
|
if not model_parallel_is_initialized():
|
||||||
initialize_model_parallel(tensor_model_parallel_size,
|
initialize_model_parallel(tensor_model_parallel_size,
|
||||||
pipeline_model_parallel_size, backend)
|
pipeline_model_parallel_size,
|
||||||
|
enable_expert_parallel, backend)
|
||||||
return
|
return
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
@ -1133,6 +1179,9 @@ def get_tensor_model_parallel_rank():
|
|||||||
def destroy_model_parallel():
|
def destroy_model_parallel():
|
||||||
"""Set the groups to none and destroy them."""
|
"""Set the groups to none and destroy them."""
|
||||||
global _TP
|
global _TP
|
||||||
|
|
||||||
|
pplx_finalize()
|
||||||
|
|
||||||
if _TP:
|
if _TP:
|
||||||
_TP.destroy()
|
_TP.destroy()
|
||||||
_TP = None
|
_TP = None
|
||||||
|
|||||||
@ -23,7 +23,7 @@ from torch.distributed.rendezvous import rendezvous
|
|||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils import get_tcp_uri
|
from vllm.utils import get_tcp_uri, is_torch_equal_or_newer
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -362,12 +362,11 @@ def stateless_destroy_torch_distributed_process_group(
|
|||||||
Destroy ProcessGroup returned by
|
Destroy ProcessGroup returned by
|
||||||
stateless_init_torch_distributed_process_group().
|
stateless_init_torch_distributed_process_group().
|
||||||
"""
|
"""
|
||||||
# Lazy import for non-CUDA backends.
|
if is_torch_equal_or_newer("2.7"):
|
||||||
try:
|
pg.shutdown()
|
||||||
# pytorch <= 2.6
|
else:
|
||||||
|
# Lazy import for non-CUDA backends.
|
||||||
from torch.distributed.distributed_c10d import _shutdown_backend
|
from torch.distributed.distributed_c10d import _shutdown_backend
|
||||||
_shutdown_backend(pg)
|
_shutdown_backend(pg)
|
||||||
except ImportError:
|
|
||||||
# pytorch >= 2.7
|
|
||||||
pg.shutdown()
|
|
||||||
_unregister_process_group(pg.group_name)
|
_unregister_process_group(pg.group_name)
|
||||||
|
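The teardown path above branches on is_torch_equal_or_newer("2.7"): newer PyTorch exposes ProcessGroup.shutdown(), while older releases still need the private _shutdown_backend helper. vllm's actual version helper may be implemented differently; a common way to write such a check, assuming the packaging library, is:

import torch
from packaging import version

def is_torch_equal_or_newer(target: str) -> bool:
    # base_version drops suffixes such as "+cu121" or ".dev20240501".
    installed = version.parse(torch.__version__).base_version
    return version.parse(installed) >= version.parse(target)

# Usage mirroring the hunk above (pg is an existing ProcessGroup):
# if is_torch_equal_or_newer("2.7"):
#     pg.shutdown()
# else:
#     from torch.distributed.distributed_c10d import _shutdown_backend
#     _shutdown_backend(pg)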
|||||||
@ -27,6 +27,7 @@ batchsize_forward_time: defaultdict = defaultdict(list)
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DPMetadata:
|
class DPMetadata:
|
||||||
|
max_tokens_across_dp_cpu: torch.Tensor
|
||||||
cu_tokens_across_dp_cpu: torch.Tensor
|
cu_tokens_across_dp_cpu: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
@ -90,8 +91,10 @@ def set_forward_context(attn_metadata: Any,
|
|||||||
dtype=torch.int32)
|
dtype=torch.int32)
|
||||||
from vllm.distributed.parallel_state import get_dp_group
|
from vllm.distributed.parallel_state import get_dp_group
|
||||||
dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
|
dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
|
||||||
|
max_tokens_across_dp_cpu = torch.max(num_tokens_tensor)
|
||||||
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
|
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
|
||||||
dp_metadata = DPMetadata(cu_tokens_across_dp_cpu)
|
dp_metadata = DPMetadata(max_tokens_across_dp_cpu,
|
||||||
|
cu_tokens_across_dp_cpu)
|
||||||
|
|
||||||
global _forward_context
|
global _forward_context
|
||||||
prev_context = _forward_context
|
prev_context = _forward_context
|
||||||
|
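The forward-context hunk above extends DPMetadata with the maximum token count across data-parallel ranks alongside the existing cumulative counts; both are cheap reductions over the all-reduced per-rank token tensor. A small illustration with made-up per-rank counts:

import torch

# Hypothetical per-DP-rank token counts after dist.all_reduce.
num_tokens_tensor = torch.tensor([7, 83, 2048, 512], dtype=torch.int32)

max_tokens_across_dp_cpu = torch.max(num_tokens_tensor)            # 2048
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)   # [7, 90, 2138, 2650]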
|||||||
@ -38,8 +38,8 @@ if HAS_TRITON:
|
|||||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||||
cutlass_moe_fp4, cutlass_moe_fp8)
|
cutlass_moe_fp4, cutlass_moe_fp8)
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||||
fused_experts, fused_moe, fused_topk, get_config_file_name,
|
TritonExperts, fused_experts, fused_moe, fused_topk,
|
||||||
grouped_topk)
|
get_config_file_name, grouped_topk)
|
||||||
|
|
||||||
__all__ += [
|
__all__ += [
|
||||||
"fused_moe",
|
"fused_moe",
|
||||||
@ -49,4 +49,5 @@ if HAS_TRITON:
|
|||||||
"grouped_topk",
|
"grouped_topk",
|
||||||
"cutlass_moe_fp8",
|
"cutlass_moe_fp8",
|
||||||
"cutlass_moe_fp4",
|
"cutlass_moe_fp4",
|
||||||
|
"TritonExperts",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -5,10 +5,176 @@ from typing import Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||||
|
MoEPrepareAndFinalizeNoEP)
|
||||||
|
from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache
|
||||||
from vllm.scalar_type import scalar_types
|
from vllm.scalar_type import scalar_types
|
||||||
|
|
||||||
|
|
||||||
|
class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
ab_strides1: torch.Tensor,
|
||||||
|
c_strides1: torch.Tensor,
|
||||||
|
ab_strides2: torch.Tensor,
|
||||||
|
c_strides2: torch.Tensor,
|
||||||
|
out_dtype: torch.dtype,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.ab_strides1 = ab_strides1
|
||||||
|
self.c_strides1 = c_strides1
|
||||||
|
self.ab_strides2 = ab_strides2
|
||||||
|
self.c_strides2 = c_strides2
|
||||||
|
self.out_dtype = out_dtype
|
||||||
|
|
||||||
|
def workspace_shapes(
|
||||||
|
self,
|
||||||
|
a: torch.Tensor,
|
||||||
|
M: int,
|
||||||
|
N: int,
|
||||||
|
K: int,
|
||||||
|
topk: int,
|
||||||
|
num_experts: int,
|
||||||
|
) -> tuple[int, int, torch.dtype]:
|
||||||
|
# Note that K, N are transposed
|
||||||
|
N, K = K, N
|
||||||
|
workspace1 = M * topk * max(2 * N, K)
|
||||||
|
workspace2 = M * topk * N
|
||||||
|
return (workspace1, workspace2, self.out_dtype)
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
activation: str,
|
||||||
|
global_num_experts: int,
|
||||||
|
expert_map: Optional[torch.Tensor],
|
||||||
|
w1_scale: Optional[torch.Tensor],
|
||||||
|
w2_scale: Optional[torch.Tensor],
|
||||||
|
w1_zp: Optional[torch.Tensor],
|
||||||
|
w2_zp: Optional[torch.Tensor],
|
||||||
|
a1q_scale: Optional[torch.Tensor],
|
||||||
|
a2_scale: Optional[torch.Tensor],
|
||||||
|
workspace13: torch.Tensor,
|
||||||
|
workspace2: torch.Tensor,
|
||||||
|
expert_num_tokens: Optional[torch.Tensor],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
a1q = hidden_states
|
||||||
|
|
||||||
|
assert w1_scale is not None
|
||||||
|
assert w2_scale is not None
|
||||||
|
assert w1.dtype == torch.float8_e4m3fn
|
||||||
|
assert w2.dtype == torch.float8_e4m3fn
|
||||||
|
assert a1q.shape[1] == w1.shape[1], "Hidden size mismatch w1"
|
||||||
|
assert w1.shape[2] == w2.shape[1] * 2, "Hidden size mismatch w2"
|
||||||
|
assert w1.shape[0] == w2.shape[0], "Expert number mismatch"
|
||||||
|
assert a1q_scale is None or a1q_scale.dim(
|
||||||
|
) == 0 or a1q_scale.shape[0] == 1 or a1q_scale.shape[0] == a1q.shape[
|
||||||
|
0], "Input scale shape mismatch"
|
||||||
|
assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[
|
||||||
|
1] == w1.shape[2], "W1 scale shape mismatch"
|
||||||
|
assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[
|
||||||
|
1] == w2.shape[2], "W2 scale shape mismatch"
|
||||||
|
assert w1.shape[0] == w2.shape[0], "Weights expert number mismatch"
|
||||||
|
assert w1.shape[0] == w1_scale.shape[
|
||||||
|
0], "w1 scales expert number mismatch"
|
||||||
|
assert w1.shape[0] == w2_scale.shape[
|
||||||
|
0], "w2 scales expert number mismatch"
|
||||||
|
assert a2_scale is None or a1q_scale is None or a2_scale.shape == a1q_scale.shape, "Intermediate scale shape mismatch" # noqa: E501
|
||||||
|
assert self.ab_strides1.shape[0] == w1.shape[
|
||||||
|
0], "AB Strides 1 expert number mismatch"
|
||||||
|
assert self.c_strides1.shape[0] == w1.shape[
|
||||||
|
0], "C Strides 1 expert number mismatch"
|
||||||
|
assert self.ab_strides2.shape[0] == w2.shape[
|
||||||
|
0], "AB Strides 2 expert number mismatch"
|
||||||
|
assert self.c_strides2.shape[0] == w2.shape[
|
||||||
|
0], "C Strides 2 expert number mismatch"
|
||||||
|
assert self.out_dtype in [torch.half,
|
||||||
|
torch.bfloat16], "Invalid output dtype"
|
||||||
|
|
||||||
|
M = a1q.shape[0]
|
||||||
|
_, N, K = w2.shape # because w1 + w2 are transposed
|
||||||
|
device = a1q.device
|
||||||
|
|
||||||
|
assert w1.shape[1] == K
|
||||||
|
assert global_num_experts != -1
|
||||||
|
assert a1q_scale is not None
|
||||||
|
|
||||||
|
if expert_map is not None:
|
||||||
|
"Translate info from expert_map to topk_ids"
|
||||||
|
local_topk_ids = torch.where(expert_map[topk_ids] != -1,
|
||||||
|
expert_map[topk_ids], -1)
|
||||||
|
else:
|
||||||
|
local_topk_ids = topk_ids
|
||||||
|
|
||||||
|
topk = local_topk_ids.shape[1]
|
||||||
|
|
||||||
|
per_act_token = a1q_scale.numel() != 1 if a1q_scale is not None else (
|
||||||
|
a2_scale.numel() != 1 if a2_scale is not None else False)
|
||||||
|
|
||||||
|
expert_offsets = torch.empty((global_num_experts + 1),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device)
|
||||||
|
problem_sizes1 = torch.empty((global_num_experts, 3),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device)
|
||||||
|
problem_sizes2 = torch.empty((global_num_experts, 3),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device)
|
||||||
|
|
||||||
|
# With expert_map each Rank processes only a subset of experts. As
|
||||||
|
# a result not all of a_map and c2 tensors are filled. We fill it
|
||||||
|
# zeros for correctness.
|
||||||
|
if expert_map is not None:
|
||||||
|
a_map = torch.zeros((local_topk_ids.numel()),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device)
|
||||||
|
else:
|
||||||
|
a_map = torch.empty((local_topk_ids.numel()),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device)
|
||||||
|
|
||||||
|
c_map = torch.empty((local_topk_ids.numel()),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device)
|
||||||
|
|
||||||
|
ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets,
|
||||||
|
problem_sizes1, problem_sizes2, a_map,
|
||||||
|
c_map, global_num_experts, N, K)
|
||||||
|
|
||||||
|
a1q = _fp8_perm(a1q, a_map)
|
||||||
|
a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale
|
||||||
|
|
||||||
|
c1 = _resize_cache(workspace13, (M * topk, N * 2))
|
||||||
|
c2 = _resize_cache(workspace2, (M * topk, N))
|
||||||
|
c3 = _resize_cache(workspace13, (M * topk, K))
|
||||||
|
|
||||||
|
ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale,
|
||||||
|
expert_offsets[:-1], problem_sizes1,
|
||||||
|
self.ab_strides1, self.ab_strides1, self.c_strides1)
|
||||||
|
|
||||||
|
self.activation(activation, c2, c1)
|
||||||
|
|
||||||
|
a2q, a2q_scale = ops.scaled_fp8_quant(
|
||||||
|
c2, a2_scale, use_per_token_if_dynamic=per_act_token)
|
||||||
|
|
||||||
|
if expert_map is not None:
|
||||||
|
c3.fill_(0)
|
||||||
|
|
||||||
|
ops.cutlass_moe_mm(c3, a2q, w2, a2q_scale, w2_scale,
|
||||||
|
expert_offsets[:-1], problem_sizes2,
|
||||||
|
self.ab_strides2, self.ab_strides2, self.c_strides2)
|
||||||
|
|
||||||
|
c3 = c3[c_map]
|
||||||
|
|
||||||
|
return c3
|
||||||
|
|
||||||
|
|
||||||
#TODO make the grouped gemm kernel consistent with scaled gemm kernel
|
#TODO make the grouped gemm kernel consistent with scaled gemm kernel
|
||||||
def cutlass_moe_fp8(
|
def cutlass_moe_fp8(
|
||||||
a: torch.Tensor,
|
a: torch.Tensor,
|
||||||
@ -17,7 +183,7 @@ def cutlass_moe_fp8(
|
|||||||
w1_scale: torch.Tensor,
|
w1_scale: torch.Tensor,
|
||||||
w2_scale: torch.Tensor,
|
w2_scale: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids_: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
ab_strides1: torch.Tensor,
|
ab_strides1: torch.Tensor,
|
||||||
c_strides1: torch.Tensor,
|
c_strides1: torch.Tensor,
|
||||||
ab_strides2: torch.Tensor,
|
ab_strides2: torch.Tensor,
|
||||||
@ -59,7 +225,7 @@ def cutlass_moe_fp8(
|
|||||||
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
|
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
|
||||||
quantize the intermediate result between the gemms.
|
quantize the intermediate result between the gemms.
|
||||||
Shape: scalar or [M]
|
Shape: scalar or [M]
|
||||||
- out_dtype (torch.Tensor): The output tensor type.
|
- out_dtype (torch.dtype): The output tensor type.
|
||||||
- expert_map (Optional[torch.Tensor]): In the case of Expert parallel,
|
- expert_map (Optional[torch.Tensor]): In the case of Expert parallel,
|
||||||
every Rank is responsible for a subset of experts. expert_map is a
|
every Rank is responsible for a subset of experts. expert_map is a
|
||||||
mapping from global expert-id to local expert-id. When expert_map[i]
|
mapping from global expert-id to local expert-id. When expert_map[i]
|
||||||
@ -71,115 +237,36 @@ def cutlass_moe_fp8(
|
|||||||
Returns:
|
Returns:
|
||||||
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
|
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch"
|
|
||||||
assert w1_q.dtype == torch.float8_e4m3fn
|
|
||||||
assert w2_q.dtype == torch.float8_e4m3fn
|
|
||||||
assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1"
|
|
||||||
assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2"
|
|
||||||
assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
|
|
||||||
assert a1_scale is None or a1_scale.dim(
|
|
||||||
) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[0] == a.shape[
|
|
||||||
0], "Input scale shape mismatch"
|
|
||||||
assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[
|
|
||||||
1] == w1_q.shape[2], "W1 scale shape mismatch"
|
|
||||||
assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[
|
|
||||||
1] == w2_q.shape[2], "W2 scale shape mismatch"
|
|
||||||
assert w1_q.shape[0] == w2_q.shape[0], "Weights expert number mismatch"
|
|
||||||
assert w1_q.shape[0] == w1_scale.shape[
|
|
||||||
0], "w1 scales expert number mismatch"
|
|
||||||
assert w1_q.shape[0] == w2_scale.shape[
|
|
||||||
0], "w2 scales expert number mismatch"
|
|
||||||
assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501
|
|
||||||
assert ab_strides1.shape[0] == w1_q.shape[
|
|
||||||
0], "AB Strides 1 expert number mismatch"
|
|
||||||
assert c_strides1.shape[0] == w1_q.shape[
|
|
||||||
0], "C Strides 1 expert number mismatch"
|
|
||||||
assert ab_strides2.shape[0] == w2_q.shape[
|
|
||||||
0], "AB Strides 2 expert number mismatch"
|
|
||||||
assert c_strides2.shape[0] == w2_q.shape[
|
|
||||||
0], "C Strides 2 expert number mismatch"
|
|
||||||
assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
|
|
||||||
|
|
||||||
num_experts = w1_q.size(0)
|
|
||||||
m = a.size(0)
|
|
||||||
k = w1_q.size(1)
|
|
||||||
n = w2_q.size(1)
|
|
||||||
|
|
||||||
local_topk_ids = topk_ids_
|
|
||||||
if expert_map is not None:
|
|
||||||
"Translate info from expert_map to topk_ids"
|
|
||||||
local_topk_ids = torch.where(expert_map[topk_ids_] != -1,
|
|
||||||
expert_map[topk_ids_], -1)
|
|
||||||
|
|
||||||
topk = local_topk_ids.size(1)
|
|
||||||
|
|
||||||
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
|
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
|
||||||
a2_scale.numel() != 1 if a2_scale is not None else False)
|
a2_scale.numel() != 1 if a2_scale is not None else False)
|
||||||
if apply_router_weight_on_input:
|
|
||||||
assert topk == 1, \
|
|
||||||
"apply_router_weight_on_input is only implemented for topk=1"
|
|
||||||
# TODO: this only works for topK=1, will need to update for topK>1
|
|
||||||
a = a * topk_weights.to(out_dtype)
|
|
||||||
|
|
||||||
a_q, a1_scale = ops.scaled_fp8_quant(
|
fn = mk.FusedMoEModularKernel(
|
||||||
a, a1_scale, use_per_token_if_dynamic=per_act_token)
|
MoEPrepareAndFinalizeNoEP(
|
||||||
device = a_q.device
|
per_channel_quant=per_act_token,
|
||||||
|
quant_dtype=torch.float8_e4m3fn,
|
||||||
|
),
|
||||||
|
CutlassExpertsFp8(
|
||||||
|
ab_strides1,
|
||||||
|
c_strides1,
|
||||||
|
ab_strides2,
|
||||||
|
c_strides2,
|
||||||
|
out_dtype,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
expert_offsets = torch.empty((num_experts + 1),
|
return fn(
|
||||||
dtype=torch.int32,
|
a,
|
||||||
device=device)
|
w1_q,
|
||||||
problem_sizes1 = torch.empty((num_experts, 3),
|
w2_q,
|
||||||
dtype=torch.int32,
|
topk_weights,
|
||||||
device=device)
|
topk_ids,
|
||||||
problem_sizes2 = torch.empty((num_experts, 3),
|
expert_map=expert_map,
|
||||||
dtype=torch.int32,
|
w1_scale=w1_scale,
|
||||||
device=device)
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a1_scale,
|
||||||
a_map_initializer = torch.empty
|
a2_scale=a2_scale,
|
||||||
c2_initializer = torch.empty
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
if expert_map is not None:
|
)
|
||||||
# With expert_map each Rank processes only a subset of experts. As
|
|
||||||
# a result not all of a_map and c2 tensors are filled. We fill it
|
|
||||||
# zeros for correctness.
|
|
||||||
a_map_initializer = torch.zeros
|
|
||||||
c2_initializer = torch.zeros
|
|
||||||
|
|
||||||
a_map = a_map_initializer((local_topk_ids.numel()),
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=device)
|
|
||||||
c_map = torch.empty((local_topk_ids.numel()),
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=device)
|
|
||||||
|
|
||||||
ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets, problem_sizes1,
|
|
||||||
problem_sizes2, a_map, c_map, num_experts, n,
|
|
||||||
k)
|
|
||||||
|
|
||||||
rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype)
|
|
||||||
rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale
|
|
||||||
|
|
||||||
c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
|
|
||||||
c2 = c2_initializer((m * topk, k), device=device, dtype=out_dtype)
|
|
||||||
|
|
||||||
ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a1_scales, w1_scale,
|
|
||||||
expert_offsets[:-1], problem_sizes1, ab_strides1,
|
|
||||||
ab_strides1, c_strides1)
|
|
||||||
|
|
||||||
intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
|
|
||||||
torch.ops._C.silu_and_mul(intermediate, c1)
|
|
||||||
|
|
||||||
intemediate_q, a2_scale = ops.scaled_fp8_quant(
|
|
||||||
intermediate, a2_scale, use_per_token_if_dynamic=per_act_token)
|
|
||||||
|
|
||||||
ops.cutlass_moe_mm(c2, intemediate_q, w2_q, a2_scale, w2_scale,
|
|
||||||
expert_offsets[:-1], problem_sizes2, ab_strides2,
|
|
||||||
ab_strides2, c_strides2)
|
|
||||||
# Gather tokens
|
|
||||||
c2 = c2[c_map].view(m, topk, k)
|
|
||||||
if not apply_router_weight_on_input:
|
|
||||||
c2 = c2 * topk_weights.view(m, topk, 1).to(out_dtype)
|
|
||||||
return c2.sum(dim=1)
|
|
||||||
|
|
||||||
|
|
||||||
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
|
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
|
||||||
|
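cutlass_moe_fp8 above no longer quantizes, permutes and runs the grouped GEMMs inline; it builds a FusedMoEModularKernel out of a prepare/finalize object (MoEPrepareAndFinalizeNoEP) and an experts object (CutlassExpertsFp8) and simply calls it. A heavily simplified sketch of this composition pattern follows; the class names and signatures here are illustrative stand-ins, not vllm's actual modular_kernel interfaces:

from abc import ABC, abstractmethod
import torch

class PrepareAndFinalize(ABC):
    """Pre-process activations before the experts, reduce afterwards."""

    @abstractmethod
    def prepare(self, a: torch.Tensor) -> torch.Tensor: ...

    @abstractmethod
    def finalize(self, out: torch.Tensor, expert_out: torch.Tensor,
                 topk_weights: torch.Tensor) -> None: ...

class Experts(ABC):
    """Run the per-expert computation on the prepared activations."""

    @abstractmethod
    def apply(self, a: torch.Tensor, topk_ids: torch.Tensor) -> torch.Tensor: ...

class ModularKernel:
    """prepare -> experts.apply -> finalize, like FusedMoEModularKernel."""

    def __init__(self, prep: PrepareAndFinalize, experts: Experts):
        self.prep, self.experts = prep, experts

    def __call__(self, a, topk_weights, topk_ids):
        a_prep = self.prep.prepare(a)
        expert_out = self.experts.apply(a_prep, topk_ids)   # [M, topk, K]
        out = torch.empty_like(a)
        self.prep.finalize(out, expert_out, topk_weights)
        return out

class NoEPPrepareFinalize(PrepareAndFinalize):
    def prepare(self, a):
        return a  # a real implementation would quantize here

    def finalize(self, out, expert_out, topk_weights):
        # Weighted reduction over the top-k expert outputs.
        torch.sum(expert_out * topk_weights[..., None], dim=1, out=out)

class NaiveExperts(Experts):
    def __init__(self, w: torch.Tensor):  # w: [E, K, K], square for simplicity
        self.w = w

    def apply(self, a, topk_ids):
        # Gather each token's top-k expert weights and apply them.
        return torch.einsum("mk,mtko->mto", a, self.w[topk_ids])

M, K, E, topk = 4, 8, 6, 2
a = torch.randn(M, K)
w = torch.randn(E, K, K)
topk_ids = torch.randint(0, E, (M, topk))
topk_weights = torch.softmax(torch.randn(M, topk), dim=-1)

fn = ModularKernel(NoEPPrepareFinalize(), NaiveExperts(w))
out = fn(a, topk_weights, topk_ids)
assert out.shape == (M, K)

The same composition is reused later in this diff: deep_gemm_moe_fp8 pairs MoEPrepareAndFinalizeNoEP with DeepGemmExperts, and the new fused_batched_moe.py provides BatchedPrepareAndFinalize / BatchedExperts for the PPLX-style expert-batched path.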
|||||||
@ -1,16 +1,17 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
import functools
|
||||||
import importlib.util
|
import importlib.util
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
from vllm import _custom_ops as ops
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
|
||||||
moe_align_block_size)
|
_moe_permute)
|
||||||
from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
|
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||||
_fp8_quantize,
|
MoEPrepareAndFinalizeNoEP)
|
||||||
|
from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
|
||||||
_resize_cache)
|
_resize_cache)
|
||||||
from vllm.utils import round_up
|
from vllm.utils import round_up
|
||||||
|
|
||||||
@ -19,6 +20,19 @@ logger = init_logger(__name__)
|
|||||||
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
|
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
|
||||||
|
def deep_gemm_block_shape() -> list[int]:
|
||||||
|
# Lazy import to avoid CUDA initialization problems.
|
||||||
|
import deep_gemm as dg
|
||||||
|
block = dg.get_m_alignment_for_contiguous_layout()
|
||||||
|
return [block, block]
|
||||||
|
|
||||||
|
|
||||||
|
def _valid_deep_gemm_shape(M: int, N: int, K: int):
|
||||||
|
align = deep_gemm_block_shape()[0]
|
||||||
|
return align <= M and N % align == 0 and K % align == 0
|
||||||
|
|
||||||
|
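_valid_deep_gemm_shape above gates the DeepGemm path on the M-alignment required by the grouped contiguous layout: M must cover at least one alignment block, and N and K must be exact multiples of it. Assuming the alignment returned by get_m_alignment_for_contiguous_layout() is 128 (an assumption for illustration; the real code queries deep_gemm at runtime), the check behaves like this:

def valid_deep_gemm_shape(M: int, N: int, K: int, align: int = 128) -> bool:
    # Mirror of _valid_deep_gemm_shape with the alignment made explicit.
    return align <= M and N % align == 0 and K % align == 0

assert not valid_deep_gemm_shape(7, 4096, 7168)     # M smaller than one block
assert not valid_deep_gemm_shape(256, 4000, 7168)   # N not a multiple of 128
assert valid_deep_gemm_shape(2048, 4096, 7168)      # aligned problem size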
|
||||||
def _valid_deep_gemm(hidden_states: torch.Tensor,
|
def _valid_deep_gemm(hidden_states: torch.Tensor,
|
||||||
w1: torch.Tensor,
|
w1: torch.Tensor,
|
||||||
w2: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
@ -29,89 +43,112 @@ def _valid_deep_gemm(hidden_states: torch.Tensor,
|
|||||||
aligned by `dg.get_m_alignment_for_contiguous_layout()`.
|
aligned by `dg.get_m_alignment_for_contiguous_layout()`.
|
||||||
"""
|
"""
|
||||||
if not has_deep_gemm:
|
if not has_deep_gemm:
|
||||||
|
logger.debug("DeepGemm disabled: deep_gemm not available.")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Lazy import to avoid CUDA initialization problems.
|
|
||||||
import deep_gemm as dg
|
|
||||||
|
|
||||||
# Expert maps not supported yet.
|
|
||||||
if expert_map is not None:
|
if expert_map is not None:
|
||||||
|
logger.debug("DeepGemm disabled: expert map NYI.")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
align = dg.get_m_alignment_for_contiguous_layout()
|
M = hidden_states.size(0)
|
||||||
M = hidden_states.shape[0]
|
_, K, N = w2.size()
|
||||||
_, K, N = w2.shape
|
if not _valid_deep_gemm_shape(M, N, K):
|
||||||
|
logger.debug("DeepGemm disabled: unalinged problem size.")
|
||||||
# For now, disable DeepGemm for small N until better permute/unpermute
|
|
||||||
# ops are available.
|
|
||||||
if N <= 512:
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if align > M or N % align != 0 or K % align != 0:
|
if (w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn):
|
||||||
|
logger.debug("DeepGemm disabled: invalid weight dtype(s).")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return (hidden_states.is_contiguous() and w1.is_contiguous()
|
if (not hidden_states.is_contiguous() or not w1.is_contiguous()
|
||||||
and w2.is_contiguous())
|
or not w2.is_contiguous()):
|
||||||
|
logger.debug(
|
||||||
|
"DeepGemm disabled: weights or activations not contiguous.")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def _moe_permute(
|
class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||||
curr_hidden_states: torch.Tensor,
|
|
||||||
a1q_scale: Optional[torch.Tensor],
|
|
||||||
curr_topk_ids: torch.Tensor,
|
|
||||||
global_num_experts: int,
|
|
||||||
expert_map: Optional[torch.Tensor],
|
|
||||||
block_m: int,
|
|
||||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
|
|
||||||
Optional[torch.Tensor]]:
|
|
||||||
"""
|
|
||||||
Determine the sorted_token_ids, expert_ids for the given problem size.
|
|
||||||
Permute the hidden states and scales according to `sorted_token_ids`.
|
|
||||||
"""
|
|
||||||
top_k_num = curr_topk_ids.shape[1]
|
|
||||||
|
|
||||||
tokens_in_chunk, _ = curr_hidden_states.shape
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.block_shape = deep_gemm_block_shape()
|
||||||
|
|
||||||
sorted_token_ids, expert_ids, num_tokens_post_padded = (
|
def workspace_shapes(
|
||||||
moe_align_block_size(curr_topk_ids,
|
self,
|
||||||
block_m,
|
a: torch.Tensor,
|
||||||
global_num_experts,
|
M: int,
|
||||||
expert_map,
|
N: int,
|
||||||
pad_sorted_ids=True))
|
K: int,
|
||||||
|
topk: int,
|
||||||
|
num_experts: int,
|
||||||
|
) -> tuple[int, int, torch.dtype]:
|
||||||
|
block_m = self.block_shape[0]
|
||||||
|
M_sum = (M * topk) + num_experts * (block_m - 1)
|
||||||
|
M_sum = round_up(M_sum, block_m)
|
||||||
|
workspace1 = M_sum * max(N * 2, K)
|
||||||
|
workspace2 = M_sum * N
|
||||||
|
return (workspace1, workspace2, a.dtype)
|
||||||
|
|
||||||
inv_perm: Optional[torch.Tensor] = None
|
def apply(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
activation: str,
|
||||||
|
global_num_experts: int,
|
||||||
|
expert_map: Optional[torch.Tensor],
|
||||||
|
w1_scale: Optional[torch.Tensor],
|
||||||
|
w2_scale: Optional[torch.Tensor],
|
||||||
|
w1_zp: Optional[torch.Tensor],
|
||||||
|
w2_zp: Optional[torch.Tensor],
|
||||||
|
a1q_scale: Optional[torch.Tensor],
|
||||||
|
a2_scale: Optional[torch.Tensor],
|
||||||
|
workspace13: torch.Tensor,
|
||||||
|
workspace2: torch.Tensor,
|
||||||
|
expert_num_tokens: Optional[torch.Tensor],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
import deep_gemm as dg
|
||||||
|
|
||||||
num_tokens = top_k_num * tokens_in_chunk
|
a1q = hidden_states
|
||||||
sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
|
_, N, K = w1.size()
|
||||||
expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)
|
|
||||||
inv_perm = torch.argsort(sorted_token_ids)[:num_tokens]
|
|
||||||
|
|
||||||
# Permute according to sorted token ids.
|
assert global_num_experts != -1
|
||||||
curr_hidden_states = _fp8_perm(curr_hidden_states,
|
assert w2.size(1) == K
|
||||||
sorted_token_ids // top_k_num)
|
|
||||||
|
|
||||||
if a1q_scale is not None:
|
a1q, a1q_scale, _, expert_ids, inv_perm = _moe_permute(
|
||||||
a1q_scale = a1q_scale[sorted_token_ids // top_k_num]
|
a1q,
|
||||||
|
a1q_scale,
|
||||||
|
topk_ids,
|
||||||
|
global_num_experts,
|
||||||
|
expert_map,
|
||||||
|
self.block_shape[0],
|
||||||
|
)
|
||||||
|
|
||||||
return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
|
# Note: M_sum is different than the pre-permuted shape of a1q.
|
||||||
inv_perm)
|
M_sum = a1q.size(0)
|
||||||
|
workspace1 = _resize_cache(workspace13, (M_sum, N))
|
||||||
|
workspace2 = _resize_cache(workspace2, (M_sum, N // 2))
|
||||||
|
workspace3 = _resize_cache(workspace13, (M_sum, K))
|
||||||
|
|
||||||
|
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||||
|
(a1q, a1q_scale), (w1, w1_scale), workspace1, expert_ids)
|
||||||
|
|
||||||
def _moe_unpermute_and_reduce(
|
self.activation(activation, workspace2, workspace1.view(-1, N))
|
||||||
out: torch.Tensor,
|
|
||||||
curr_hidden: torch.Tensor,
|
a2q_scale: Optional[torch.Tensor] = None
|
||||||
inv_perm: Optional[torch.Tensor],
|
|
||||||
topk_weight: torch.Tensor,
|
a2q, a2q_scale = _fp8_quantize(workspace2, a2_scale, False,
|
||||||
) -> None:
|
self.block_shape)
|
||||||
"""
|
|
||||||
Unpermute the final result and apply topk_weights, then perform the final
|
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||||
reduction on the hidden states.
|
(a2q, a2q_scale), (w2, w2_scale), workspace3, expert_ids)
|
||||||
"""
|
|
||||||
M, topk = topk_weight.shape
|
workspace3 = workspace3[inv_perm, ...]
|
||||||
K = curr_hidden.shape[1]
|
|
||||||
curr_hidden = curr_hidden[inv_perm, ...]
|
return workspace3
|
||||||
curr_hidden = curr_hidden.view(-1, topk, K)
|
|
||||||
curr_hidden.mul_(topk_weight.view(M, -1, 1))
|
|
||||||
ops.moe_sum(curr_hidden, out)
|
|
||||||
|
|
||||||
|
|
||||||
def deep_gemm_moe_fp8(
|
def deep_gemm_moe_fp8(
|
||||||
@ -128,6 +165,7 @@ def deep_gemm_moe_fp8(
|
|||||||
expert_map: Optional[torch.Tensor] = None,
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
a1_scale: Optional[torch.Tensor] = None,
|
a1_scale: Optional[torch.Tensor] = None,
|
||||||
a2_scale: Optional[torch.Tensor] = None,
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
|
apply_router_weight_on_input=False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
|
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
|
||||||
@ -166,129 +204,24 @@ def deep_gemm_moe_fp8(
|
|||||||
Returns:
|
Returns:
|
||||||
- torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
|
- torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
|
||||||
"""
|
"""
|
||||||
# Lazy import to avoid CUDA initialization problems.
|
fn = mk.FusedMoEModularKernel(
|
||||||
import deep_gemm as dg
|
MoEPrepareAndFinalizeNoEP(quant_dtype=torch.float8_e4m3fn,
|
||||||
|
block_shape=deep_gemm_block_shape()),
|
||||||
assert expert_map is None, "Expert maps not supported yet"
|
DeepGemmExperts(),
|
||||||
|
)
|
||||||
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
|
return fn(
|
||||||
|
hidden_states,
|
||||||
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
w1,
|
||||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
w2,
|
||||||
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
|
topk_weights,
|
||||||
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
|
topk_ids,
|
||||||
assert hidden_states.dtype in [
|
inplace,
|
||||||
torch.float32, torch.float16, torch.bfloat16
|
activation,
|
||||||
]
|
global_num_experts,
|
||||||
assert w1.dtype == torch.float8_e4m3fn
|
expert_map,
|
||||||
assert w2.dtype == torch.float8_e4m3fn
|
w1_scale=w1_scale,
|
||||||
assert w1.shape[0] == w2.shape[0], "Expert number mismatch"
|
w2_scale=w2_scale,
|
||||||
assert w1.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
|
a1_scale=a1_scale,
|
||||||
assert w1.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
|
a2_scale=a2_scale,
|
||||||
assert a1_scale is None or a1_scale.dim(
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[
|
)
|
||||||
0] == hidden_states.shape[0], "Input scale shape mismatch"
|
|
||||||
assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501
|
|
||||||
|
|
||||||
num_tokens, _ = hidden_states.shape
|
|
||||||
E, N, _ = w1.shape
|
|
||||||
K = w2.shape[1]
|
|
||||||
if global_num_experts == -1:
|
|
||||||
global_num_experts = E
|
|
||||||
|
|
||||||
# We execute the fused_moe kernel in chunks to circumvent this issue:
|
|
||||||
# https://github.com/vllm-project/vllm/issues/5938
|
|
||||||
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
|
||||||
|
|
||||||
assert _valid_deep_gemm(hidden_states, w1, w2, expert_map)
|
|
||||||
|
|
||||||
if inplace:
|
|
||||||
out_hidden_states = hidden_states
|
|
||||||
else:
|
|
||||||
out_hidden_states = torch.empty_like(hidden_states)
|
|
||||||
|
|
||||||
block_m = dg.get_m_alignment_for_contiguous_layout()
|
|
||||||
block_shape = [block_m, block_m]
|
|
||||||
|
|
||||||
assert w1_scale is not None
|
|
||||||
assert w2_scale is not None
|
|
||||||
|
|
||||||
# We attempt to transpose and align offline in Fp8MoEMethod, in which
|
|
||||||
# case these calls will be nops. Otherwise, they'll be performed every
|
|
||||||
# time the layer is executed.
|
|
||||||
w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous()
|
|
||||||
w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous()
|
|
||||||
|
|
||||||
M_sum = topk_ids.numel() + global_num_experts * (block_m - 1)
|
|
||||||
M_sum = round_up(M_sum, block_m)
|
|
||||||
|
|
||||||
num_chunks = (num_tokens // CHUNK_SIZE) + 1
|
|
||||||
|
|
||||||
# We can reuse the memory between cache1 and cache3 because by the time
|
|
||||||
# we need cache3, we're done with cache1
|
|
||||||
workspace13 = torch.empty(M_sum * max(N, K),
|
|
||||||
device=hidden_states.device,
|
|
||||||
dtype=hidden_states.dtype)
|
|
||||||
|
|
||||||
workspace1 = workspace13[:M_sum * N].view(M_sum, N)
|
|
||||||
workspace2 = torch.empty((M_sum, N // 2),
|
|
||||||
device=hidden_states.device,
|
|
||||||
dtype=hidden_states.dtype)
|
|
||||||
workspace3 = workspace13[:M_sum * K].view(M_sum, K)
|
|
||||||
|
|
||||||
for chunk in range(num_chunks):
|
|
||||||
begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
|
|
||||||
min((chunk + 1) * CHUNK_SIZE,
|
|
||||||
num_tokens))
|
|
||||||
curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
|
|
||||||
tokens_in_chunk, _ = curr_hidden_states.shape
|
|
||||||
|
|
||||||
if tokens_in_chunk == 0:
|
|
||||||
break
|
|
||||||
|
|
||||||
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
|
|
||||||
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
|
|
||||||
|
|
||||||
a1q_scale: Optional[torch.Tensor] = None
|
|
||||||
|
|
||||||
qcurr_hidden_states, a1q_scale = _fp8_quantize(curr_hidden_states,
|
|
||||||
a1_scale, block_shape)
|
|
||||||
|
|
||||||
(qcurr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
|
|
||||||
inv_perm) = _moe_permute(qcurr_hidden_states, a1q_scale,
|
|
||||||
curr_topk_ids, global_num_experts,
|
|
||||||
expert_map, block_m)
|
|
||||||
|
|
||||||
# Adjust the intermediate cache size and config for the last chunk.
|
|
||||||
# Note that in most cases we only have one chunk so the cache size
|
|
||||||
# and config are already set correctly and do not need to be adjusted.
|
|
||||||
if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
|
|
||||||
curr_M = sorted_token_ids.numel()
|
|
||||||
workspace1 = _resize_cache(workspace1, (curr_M, N))
|
|
||||||
workspace2 = _resize_cache(workspace2, (curr_M, N // 2))
|
|
||||||
workspace3 = _resize_cache(workspace3, (curr_M, K))
|
|
||||||
|
|
||||||
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
|
||||||
(qcurr_hidden_states, a1q_scale), (w1, w1_scale), workspace1,
|
|
||||||
expert_ids)
|
|
||||||
|
|
||||||
if activation == "silu":
|
|
||||||
torch.ops._C.silu_and_mul(workspace2, workspace1.view(-1, N))
|
|
||||||
elif activation == "gelu":
|
|
||||||
torch.ops._C.gelu_and_mul(workspace2, workspace1.view(-1, N))
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
|
|
||||||
|
|
||||||
a2q_scale: Optional[torch.Tensor] = None
|
|
||||||
|
|
||||||
qworkspace2, a2q_scale = _fp8_quantize(workspace2, a2_scale,
|
|
||||||
block_shape)
|
|
||||||
|
|
||||||
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
|
||||||
(qworkspace2, a2q_scale), (w2, w2_scale), workspace3, expert_ids)
|
|
||||||
|
|
||||||
_moe_unpermute_and_reduce(
|
|
||||||
out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
|
||||||
workspace3.view(*workspace3.shape), inv_perm, curr_topk_weights)
|
|
||||||
|
|
||||||
return out_hidden_states
|
|
||||||
|
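DeepGemmExperts.workspace_shapes above sizes its scratch buffers for the permuted, block-padded layout produced by _moe_permute: in the worst case there are M * topk real rows plus up to block_m - 1 padding rows per expert, rounded up to a multiple of block_m. A small worked example of that arithmetic, assuming block_m = 128 as the DeepGemm contiguous-layout alignment:

def round_up(x: int, a: int) -> int:
    return (x + a - 1) // a * a

def deep_gemm_m_sum(M: int, topk: int, num_experts: int, block_m: int = 128) -> int:
    # Worst-case padded row count used to size workspace1 and workspace2.
    m_sum = (M * topk) + num_experts * (block_m - 1)
    return round_up(m_sum, block_m)

assert deep_gemm_m_sum(7, 2, 8) == 1152      # 14 real rows -> 1152 padded rows
assert deep_gemm_m_sum(2048, 2, 8) == 5120   # 4096 real rows -> 5120 padded rows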
|||||||
755
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
Normal file
@ -0,0 +1,755 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
"""Fused batched MoE kernel."""
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
|
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||||
|
get_config_dtype_str, try_get_optimal_moe_config)
|
||||||
|
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def moe_mmk(
|
||||||
|
a_ptrs,
|
||||||
|
b_ptrs,
|
||||||
|
K,
|
||||||
|
expert_id,
|
||||||
|
a_scale_ptr,
|
||||||
|
b_scale_ptr,
|
||||||
|
# The stride variables represent how much to increase the ptr by when
|
||||||
|
# moving by 1 element in a particular dimension. E.g. `stride_am` is
|
||||||
|
# how much to increase `a_ptr` by to get the element one row down
|
||||||
|
# (A has M rows).
|
||||||
|
stride_ak,
|
||||||
|
stride_bk,
|
||||||
|
stride_asm,
|
||||||
|
stride_ask,
|
||||||
|
stride_bse,
|
||||||
|
stride_bsk,
|
||||||
|
stride_bsn,
|
||||||
|
# Offsets and masks
|
||||||
|
offs_m,
|
||||||
|
offs_n,
|
||||||
|
mask_m,
|
||||||
|
# Block size for block-wise quantization
|
||||||
|
group_n: tl.constexpr,
|
||||||
|
group_k: tl.constexpr,
|
||||||
|
# Meta-parameters
|
||||||
|
BLOCK_M: tl.constexpr,
|
||||||
|
BLOCK_N: tl.constexpr,
|
||||||
|
BLOCK_K: tl.constexpr,
|
||||||
|
compute_type: tl.constexpr,
|
||||||
|
use_w8a8: tl.constexpr,
|
||||||
|
use_w8a16: tl.constexpr):
|
||||||
|
|
||||||
|
offs_k = tl.arange(0, BLOCK_K)
|
||||||
|
|
||||||
|
if use_w8a16:
|
||||||
|
b_scale_ptrs = b_scale_ptr + expert_id * stride_bse + offs_n[
|
||||||
|
None, :] * stride_bsn
|
||||||
|
b_scale = tl.load(b_scale_ptrs)
|
||||||
|
|
||||||
|
if use_w8a8:
|
||||||
|
# block-wise
|
||||||
|
if group_k > 0 and group_n > 0:
|
||||||
|
a_scale_ptrs = a_scale_ptr + offs_m * stride_asm
|
||||||
|
offs_bsn = offs_n // group_n
|
||||||
|
b_scale_ptrs = (b_scale_ptr + expert_id * stride_bse +
|
||||||
|
offs_bsn * stride_bsn)
|
||||||
|
# tensor-wise
|
||||||
|
else:
|
||||||
|
a_scale = tl.load(a_scale_ptr)
|
||||||
|
b_scale = tl.load(b_scale_ptr + expert_id)
|
||||||
|
|
||||||
|
# -----------------------------------------------------------
|
||||||
|
# Iterate to compute a block of the C matrix.
|
||||||
|
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
|
||||||
|
# of fp32 values for higher accuracy.
|
||||||
|
# `accumulator` will be converted back to fp16 after the loop.
|
||||||
|
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
|
||||||
|
for k in range(0, tl.cdiv(K, BLOCK_K)):
|
||||||
|
# Load the next block of A and B, generate a mask by checking the
|
||||||
|
# K dimension.
|
||||||
|
a = tl.load(a_ptrs,
|
||||||
|
mask=mask_m[:, None] & (offs_k[None, :] < K - k * BLOCK_K),
|
||||||
|
other=0.0)
|
||||||
|
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)
|
||||||
|
# We accumulate along the K dimension.
|
||||||
|
if use_w8a16:
|
||||||
|
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
|
||||||
|
elif use_w8a8:
|
||||||
|
if group_k > 0 and group_n > 0:
|
||||||
|
k_start = k * BLOCK_K
|
||||||
|
offs_ks = k_start // group_k
|
||||||
|
a_scale = tl.load(a_scale_ptrs + offs_ks * stride_ask,
|
||||||
|
mask=mask_m,
|
||||||
|
other=0.0)
|
||||||
|
b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
|
||||||
|
|
||||||
|
accumulator += tl.dot(a, b) * a_scale[:,
|
||||||
|
None] * b_scale[None, :]
|
||||||
|
else:
|
||||||
|
if use_w8a8:
|
||||||
|
# acc used to enable fp8_fast_accum
|
||||||
|
accumulator = tl.dot(a, b, acc=accumulator)
|
||||||
|
else:
|
||||||
|
accumulator += tl.dot(a, b)
|
||||||
|
else:
|
||||||
|
accumulator += tl.dot(a, b)
|
||||||
|
# Advance the ptrs to the next K block.
|
||||||
|
a_ptrs += BLOCK_K * stride_ak
|
||||||
|
b_ptrs += BLOCK_K * stride_bk
|
||||||
|
|
||||||
|
if use_w8a16:
|
||||||
|
accumulator = (accumulator * b_scale).to(compute_type)
|
||||||
|
elif use_w8a8:
|
||||||
|
if group_k > 0 and group_n > 0:
|
||||||
|
accumulator = accumulator.to(compute_type)
|
||||||
|
else:
|
||||||
|
accumulator = (accumulator * a_scale * b_scale).to(compute_type)
|
||||||
|
else:
|
||||||
|
accumulator = accumulator.to(compute_type)
|
||||||
|
|
||||||
|
return accumulator
|
||||||
|
|
||||||
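In the use_w8a8 block-wise branch above, moe_mmk rescales every K-block of the accumulating dot product with one activation scale per (row, K-block) and one weight scale per (K-block, N-group), taken from the current expert's slice of b_scale. A plain-PyTorch reference for that scaling scheme (an illustration of the math for a single expert, not the Triton kernel itself):

import torch

def blockwise_w8a8_matmul_ref(a_q, b_q, a_scale, b_scale, group_k, group_n):
    # a_q: [M, K], a_scale: [M, K // group_k]
    # b_q: [K, N], b_scale: [K // group_k, N // group_n]  (one expert's slice)
    M, K = a_q.shape
    N = b_q.shape[1]
    out = torch.zeros(M, N, dtype=torch.float32)
    for k0 in range(0, K, group_k):
        kb = k0 // group_k
        partial = a_q[:, k0:k0 + group_k].float() @ b_q[k0:k0 + group_k].float()
        a_s = a_scale[:, kb].unsqueeze(1)                          # [M, 1]
        b_s = b_scale[kb].repeat_interleave(group_n).unsqueeze(0)  # [1, N]
        out += partial * a_s * b_s
    return out

# Smoke test with unit scales: reduces to an ordinary matmul.
M, K, N, gk, gn = 4, 256, 128, 128, 64
a_q = torch.randint(-8, 8, (M, K))
b_q = torch.randint(-8, 8, (K, N))
ref = blockwise_w8a8_matmul_ref(a_q, b_q, torch.ones(M, K // gk),
                                torch.ones(K // gk, N // gn), gk, gn)
assert torch.equal(ref, a_q.float() @ b_q.float())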
|
|
||||||
|
@triton.jit
|
||||||
|
def expert_triton_kernel(
|
||||||
|
a_ptr, #[max_tokens, K]
|
||||||
|
b_ptr, #[K, N]
|
||||||
|
c_ptr, #[max_tokens, N]
|
||||||
|
expert_id,
|
||||||
|
compute_type: tl.constexpr,
|
||||||
|
# Dimensions
|
||||||
|
M,
|
||||||
|
N,
|
||||||
|
K,
|
||||||
|
# Quantization data
|
||||||
|
a_scale_ptr,
|
||||||
|
b_scale_ptr,
|
||||||
|
b_zp_ptr,
|
||||||
|
# strides
|
||||||
|
stride_am,
|
||||||
|
stride_ak,
|
||||||
|
stride_bk,
|
||||||
|
stride_bn,
|
||||||
|
stride_cm,
|
||||||
|
stride_cn,
|
||||||
|
stride_asm,
|
||||||
|
stride_ask,
|
||||||
|
stride_bse,
|
||||||
|
stride_bsk,
|
||||||
|
stride_bsn,
|
||||||
|
# Blockwise quantization data
|
||||||
|
group_n,
|
||||||
|
group_k,
|
||||||
|
# Quantization schemes
|
||||||
|
use_fp8_w8a8: tl.constexpr,
|
||||||
|
use_int8_w8a16: tl.constexpr,
|
||||||
|
# Kernel config
|
||||||
|
BLOCK_M: tl.constexpr,
|
||||||
|
BLOCK_N: tl.constexpr,
|
||||||
|
BLOCK_K: tl.constexpr):
|
||||||
|
|
||||||
|
offs_m = tl.arange(0, BLOCK_M)
|
||||||
|
offs_n = tl.arange(0, BLOCK_N) % N
|
||||||
|
offs_k = tl.arange(0, BLOCK_K)
|
||||||
|
mask_m = offs_m < M
|
||||||
|
|
||||||
|
a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
|
||||||
|
b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
|
||||||
|
|
||||||
|
accumulator = moe_mmk(
|
||||||
|
a_ptrs,
|
||||||
|
b_ptrs,
|
||||||
|
K,
|
||||||
|
expert_id,
|
||||||
|
a_scale_ptr,
|
||||||
|
b_scale_ptr,
|
||||||
|
# The stride variables represent how much to increase the ptr by when
|
||||||
|
# moving by 1 element in a particular dimension. E.g. `stride_am` is
|
||||||
|
# how much to increase `a_ptr` by to get the element one row down
|
||||||
|
# (A has M rows).
|
||||||
|
stride_ak,
|
||||||
|
stride_bk,
|
||||||
|
stride_asm,
|
||||||
|
stride_ask,
|
||||||
|
stride_bse,
|
||||||
|
stride_bsk,
|
||||||
|
stride_bsn,
|
||||||
|
# Offsets and masks
|
||||||
|
offs_m,
|
||||||
|
offs_n,
|
||||||
|
mask_m,
|
||||||
|
# Block size for block-wise quantization
|
||||||
|
group_n,
|
||||||
|
group_k,
|
||||||
|
# Meta-parameters
|
||||||
|
BLOCK_M,
|
||||||
|
BLOCK_N,
|
||||||
|
BLOCK_K,
|
||||||
|
compute_type,
|
||||||
|
use_fp8_w8a8,
|
||||||
|
use_int8_w8a16)
|
||||||
|
|
||||||
|
# store in C
|
||||||
|
offs_cn = tl.arange(0, BLOCK_N)
|
||||||
|
c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_cn[None, :] * stride_cn
|
||||||
|
c_mask = mask_m[:, None] & (offs_cn[None, :] < N)
|
||||||
|
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def batched_triton_kernel(
|
||||||
|
a_ptr, # [E, max_num_tokens, K]
|
||||||
|
b_ptr, # [E, K, N]
|
||||||
|
c_ptr, # [E, max_num_tokens, N]
|
||||||
|
expert_num_tokens, # [E]
|
||||||
|
compute_type: tl.constexpr,
|
||||||
|
# Dimensions
|
||||||
|
max_num_tokens,
|
||||||
|
K,
|
||||||
|
N,
|
||||||
|
# Quantization data
|
||||||
|
a_scale_ptr,
|
||||||
|
b_scale_ptr,
|
||||||
|
b_zp_ptr,
|
||||||
|
# The stride variables represent how much to increase the ptr by when
|
||||||
|
# moving by 1 element in a particular dimension. E.g. `stride_am` is
|
||||||
|
# how much to increase `a_ptr` by to get the element one row down
|
||||||
|
# (A has M rows).
|
||||||
|
stride_ae,
|
||||||
|
stride_am,
|
||||||
|
stride_ak,
|
||||||
|
stride_be,
|
||||||
|
stride_bk,
|
||||||
|
stride_bn,
|
||||||
|
stride_ce,
|
||||||
|
stride_cm,
|
||||||
|
stride_cn,
|
||||||
|
stride_asm,
|
||||||
|
stride_ask,
|
||||||
|
stride_bse,
|
||||||
|
stride_bsk,
|
||||||
|
stride_bsn,
|
||||||
|
# Blockwise quantization data
|
||||||
|
group_n: tl.constexpr,
|
||||||
|
group_k: tl.constexpr,
|
||||||
|
# Quantization schemes
|
||||||
|
use_fp8_w8a8: tl.constexpr,
|
||||||
|
use_int8_w8a16: tl.constexpr,
|
||||||
|
# Kernel config
|
||||||
|
BLOCK_M: tl.constexpr,
|
||||||
|
BLOCK_N: tl.constexpr,
|
||||||
|
BLOCK_K: tl.constexpr):
|
||||||
|
expert_id = tl.program_id(axis=0)
|
||||||
|
e_num_tokens = tl.load(expert_num_tokens + expert_id)
|
||||||
|
if e_num_tokens == 0:
|
||||||
|
# Early exit
|
||||||
|
return
|
||||||
|
|
||||||
|
pid_mn = tl.program_id(axis=1)
|
||||||
|
#num_pid_m = tl.cdiv(max_num_tokens, BLOCK_M)
|
||||||
|
num_pid_n = tl.cdiv(N, BLOCK_N)
|
||||||
|
pid_m = pid_mn // num_pid_n
|
||||||
|
pid_n = pid_mn % num_pid_n
|
||||||
|
|
||||||
|
cta_m_start = pid_m * BLOCK_M
|
||||||
|
cta_n_start = pid_n * BLOCK_N
|
||||||
|
if cta_m_start >= e_num_tokens:
|
||||||
|
# Early exit
|
||||||
|
return
|
||||||
|
|
||||||
|
cta_m_size = min(BLOCK_M, e_num_tokens - cta_m_start)
|
||||||
|
cta_n_size = min(BLOCK_N, N - cta_n_start)
|
||||||
|
|
||||||
|
a_ptr = a_ptr + expert_id * stride_ae + cta_m_start * stride_am
|
||||||
|
b_ptr = b_ptr + expert_id * stride_be + cta_n_start * stride_bn
|
||||||
|
c_ptr = (c_ptr + expert_id * stride_ce + cta_m_start * stride_cm +
|
||||||
|
cta_n_start * stride_cn)
|
||||||
|
|
||||||
|
expert_triton_kernel(
|
||||||
|
a_ptr,
|
||||||
|
b_ptr,
|
||||||
|
c_ptr,
|
||||||
|
expert_id,
|
||||||
|
compute_type,
|
||||||
|
cta_m_size, # M
|
||||||
|
cta_n_size, # N
|
||||||
|
K, # K
|
||||||
|
a_scale_ptr,
|
||||||
|
b_scale_ptr,
|
||||||
|
b_zp_ptr,
|
||||||
|
# Strides
|
||||||
|
stride_am,
|
||||||
|
stride_ak,
|
||||||
|
stride_bk,
|
||||||
|
stride_bn,
|
||||||
|
stride_cm,
|
||||||
|
stride_cn,
|
||||||
|
stride_asm,
|
||||||
|
stride_ask,
|
||||||
|
stride_bse,
|
||||||
|
stride_bsk,
|
||||||
|
stride_bsn,
|
||||||
|
# Blockwise quantization data
|
||||||
|
group_n,
|
||||||
|
group_k,
|
||||||
|
# Quantization schemes
|
||||||
|
use_fp8_w8a8,
|
||||||
|
use_int8_w8a16,
|
||||||
|
# Kernel config
|
||||||
|
BLOCK_M,
|
||||||
|
BLOCK_N,
|
||||||
|
BLOCK_K)
|
||||||
|
|
||||||
|
|
||||||
|
def invoke_moe_batched_triton_kernel(
|
||||||
|
A: torch.Tensor, # [E, max_tokens, K]
|
||||||
|
B: torch.Tensor, # [E, K, N]
|
||||||
|
C: torch.Tensor, # [E, max_tokens, N]
|
||||||
|
expert_num_tokens: torch.Tensor, # [E]
|
||||||
|
compute_type: tl.dtype,
|
||||||
|
# Quantization data
|
||||||
|
A_scale: torch.Tensor,
|
||||||
|
B_scale: torch.Tensor,
|
||||||
|
B_zp: torch.Tensor,
|
||||||
|
# Quantization schemes
|
||||||
|
use_fp8_w8a8: bool,
|
||||||
|
use_int8_w8a16: bool,
|
||||||
|
use_int4_w4a16: bool,
|
||||||
|
config: dict[str, int],
|
||||||
|
block_shape: Optional[list[int]] = None):
|
||||||
|
|
||||||
|
assert not use_int4_w4a16
|
||||||
|
max_num_tokens = A.size(1)
|
||||||
|
K = A.size(2)
|
||||||
|
N = C.size(2)
|
||||||
|
|
||||||
|
BLOCK_M = config['BLOCK_SIZE_M']
|
||||||
|
BLOCK_N = config['BLOCK_SIZE_N']
|
||||||
|
BLOCK_K = config['BLOCK_SIZE_K']
|
||||||
|
assert (torch.compiler.is_compiling()
|
||||||
|
or torch.cuda.is_current_stream_capturing()
|
||||||
|
or max_num_tokens % BLOCK_M == 0)
|
||||||
|
|
||||||
|
grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) *
|
||||||
|
triton.cdiv(B.size(1), BLOCK_N))
|
||||||
|
|
||||||
|
batched_triton_kernel[grid](
|
||||||
|
A,
|
||||||
|
B,
|
||||||
|
C,
|
||||||
|
expert_num_tokens,
|
||||||
|
compute_type,
|
||||||
|
# Dimensions
|
||||||
|
max_num_tokens,
|
||||||
|
K,
|
||||||
|
N,
|
||||||
|
# Quantization data
|
||||||
|
A_scale,
|
||||||
|
B_scale,
|
||||||
|
B_zp,
|
||||||
|
# Strides
|
||||||
|
A.stride(0),
|
||||||
|
A.stride(1),
|
||||||
|
A.stride(2),
|
||||||
|
B.stride(0),
|
||||||
|
B.stride(2),
|
||||||
|
B.stride(1),
|
||||||
|
C.stride(0),
|
||||||
|
C.stride(1),
|
||||||
|
C.stride(2),
|
||||||
|
A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
|
||||||
|
A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
|
||||||
|
B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
|
||||||
|
B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
|
||||||
|
B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
|
||||||
|
# Blockwise quantization data
|
||||||
|
0 if block_shape is None else block_shape[0],
|
||||||
|
0 if block_shape is None else block_shape[1],
|
||||||
|
# Quantization schemes
|
||||||
|
use_fp8_w8a8,
|
||||||
|
use_int8_w8a16,
|
||||||
|
# Kernel config
|
||||||
|
BLOCK_M=BLOCK_M,
|
||||||
|
BLOCK_N=BLOCK_N,
|
||||||
|
BLOCK_K=BLOCK_K)
|
||||||
|
|
||||||
|
|
||||||
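invoke_moe_batched_triton_kernel above launches a 2-D grid: axis 0 carries one program per expert, and axis 1 one program per BLOCK_M x BLOCK_N output tile of that expert's max_num_tokens x N slab; programs whose tile starts past the expert's real token count return early. A small example of the grid arithmetic:

def cdiv(a: int, b: int) -> int:
    return (a + b - 1) // b

def batched_moe_grid(num_experts, max_num_tokens, N, BLOCK_M, BLOCK_N):
    # Mirrors: grid = (E, cdiv(max_num_tokens, BLOCK_M) * cdiv(N, BLOCK_N))
    return (num_experts, cdiv(max_num_tokens, BLOCK_M) * cdiv(N, BLOCK_N))

# 8 local experts, up to 512 tokens each, N = 2048, 64 x 128 tiles:
assert batched_moe_grid(8, 512, 2048, 64, 128) == (8, 128)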
|
class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
    """
    A reference prepare/finalize class that reorganizes the tokens into
    expert batched format, i.e. E x max_num_tokens x K. This is the format
    that the PPLX dispatch/combine kernels use.
    """

    def __init__(self, max_num_tokens: Optional[int], world_size: int,
                 dp_size: int, rank: int):
        super().__init__()
        self.world_size = world_size
        self.dp_size = dp_size
        self.rank = rank
        self.max_num_tokens = max_num_tokens

    def prepare(
        self,
        a1: torch.Tensor,
        a1_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor],
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: Optional[torch.Tensor],
        apply_router_weight_on_input: bool,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        assert a1.dim() == 2
        assert topk_ids.dim() == 2
        assert topk_ids.size(0) == a1.size(0)

        if apply_router_weight_on_input:
            topk = topk_ids.size(1)
            # TODO: this only works for topk=1, will need to update for topk>1
            assert topk == 1, \
                "apply_router_weight_on_input is only implemented for topk=1"
            a1.mul_(topk_weights.to(a1.dtype))

        num_tokens, hidden_dim = a1.size()
        topk = topk_ids.size(1)

        if self.max_num_tokens is None:
            tokens_per_expert = torch.bincount(topk_ids.view(-1),
                                               minlength=num_experts)
            self.max_num_tokens = int(tokens_per_expert.max().item())
        else:
            tokens_per_expert = torch.zeros(num_experts,
                                            dtype=torch.int,
                                            device=a1.device)

        assert num_experts % self.world_size == 0

        num_local_experts = num_experts // self.world_size

        b_a1 = torch.zeros(
            (num_local_experts, self.max_num_tokens, hidden_dim),
            dtype=a1.dtype,
            device=a1.device)

        first_expert = num_local_experts * self.rank
        last_expert = first_expert + num_local_experts

        for expert_id in range(first_expert, last_expert):
            topks = torch.any(topk_ids == expert_id, dim=1).flatten()
            rows = torch.count_nonzero(topks.flatten())
            b_a1[expert_id -
                 first_expert, :rows, :] = a1[:topks.numel()][topks]
            tokens_per_expert[expert_id - first_expert] = rows

        return b_a1, a1_scale, tokens_per_expert

    def finalize(
        self,
        output: torch.Tensor,
        fused_expert_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
    ) -> None:
        num_tokens = topk_ids.size(0)
        num_local_experts = fused_expert_output.size(0)
        K = fused_expert_output.size(-1)
        assert output.size(0) == num_tokens and output.size(1) == K

        output.fill_(0)

        first_expert = num_local_experts * self.rank
        last_expert = first_expert + num_local_experts

        for expert_id in range(first_expert, last_expert):
            matching_tokens = topk_ids == expert_id
            topks = torch.any(matching_tokens, dim=1).flatten()
            rows = torch.count_nonzero(topks)
            rhs = fused_expert_output[expert_id - first_expert, :rows, :]
            if not apply_router_weight_on_input:
                rhs.mul_(topk_weights[matching_tokens].view(rhs.size(0), 1))
            output[topks] = output[topks] + rhs

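# Editor's note, illustrative sketch (not part of the diff): how prepare()
# above scatters a flat [num_tokens, K] activation tensor into the
# expert-batched [num_experts, max_num_tokens, K] layout. The helper name and
# shapes are made up for the example; the real class additionally tracks
# scales, ranks and a fixed max_num_tokens. Relies on this module's torch
# import.
def _example_scatter_to_batched(a1: torch.Tensor, topk_ids: torch.Tensor,
                                num_experts: int) -> torch.Tensor:
    tokens_per_expert = torch.bincount(topk_ids.view(-1),
                                       minlength=num_experts)
    max_num_tokens = int(tokens_per_expert.max().item())
    b_a1 = torch.zeros((num_experts, max_num_tokens, a1.size(1)),
                       dtype=a1.dtype,
                       device=a1.device)
    for expert_id in range(num_experts):
        # Rows routed to this expert, packed to the front of its slot.
        rows = torch.any(topk_ids == expert_id, dim=1)
        b_a1[expert_id, :int(rows.sum()), :] = a1[rows]
    return b_a1

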
class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
    """
    A reference MoE expert class that operates on expert batched format,
    i.e. E x max_num_tokens x K. This is the format that the pplx
    dispatch/combine kernels use.
    """

    def __init__(
        self,
        world_size: int,
        dp_size: int,
        max_num_tokens: Optional[int] = None,
        use_fp8_w8a8: bool = False,
        use_int8_w8a8: bool = False,
        use_int8_w8a16: bool = False,
        use_int4_w4a16: bool = False,
        block_shape: Optional[list[int]] = None,
        block_m: Optional[int] = None,
    ):
        super().__init__()
        assert block_shape is None
        assert block_m is None
        assert not use_fp8_w8a8, "NYI"
        assert not use_int8_w8a8, "NYI"
        assert not use_int8_w8a16, "NYI"
        assert not use_int4_w4a16, "NYI"
        self.max_num_tokens = max_num_tokens
        self.world_size = world_size
        self.dp_size = dp_size

    def workspace_shapes(
        self,
        a: torch.Tensor,
        M: int,
        N: int,
        K: int,
        topk: int,
        num_experts: int,
    ) -> tuple[int, int, torch.dtype]:
        assert a.dim() == 2
        num_dp = self.world_size // self.dp_size
        max_num_tokens = a.size(
            0) if self.max_num_tokens is None else self.max_num_tokens
        #print(f"WORKSPACE {max_num_tokens} {num_dp}")
        workspace13 = num_experts * max_num_tokens * num_dp * K
        workspace2 = max_num_tokens * num_dp * N
        return (workspace13, workspace2, a.dtype)

    def apply(
        self,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int,
        expert_map: Optional[torch.Tensor],
        w1_scale: Optional[torch.Tensor],
        w2_scale: Optional[torch.Tensor],
        w1_zp: Optional[torch.Tensor],
        w2_zp: Optional[torch.Tensor],
        a1q_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor],
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_num_tokens: Optional[torch.Tensor],
    ) -> torch.Tensor:
        assert hidden_states.dim() == 3
        assert expert_num_tokens is not None
        hidden_dim = hidden_states.size(-1)

        if self.max_num_tokens is None:
            max_num_tokens = hidden_states.size(1)
        else:
            max_num_tokens = self.max_num_tokens

        num_dp = self.world_size // self.dp_size
        num_experts = global_num_experts
        out = _resize_cache(workspace13,
                            (num_experts, max_num_tokens * num_dp, hidden_dim))
        num_local_experts = w1.size(0)
        assert num_local_experts == w1.size(0), (
            f"{num_local_experts} == {w1.size(0)}")

        N = w1.size(1) // 2

        # Not cudagraph friendly
        assert (torch.compiler.is_compiling()
                or torch.cuda.is_current_stream_capturing()
                or torch.all(expert_num_tokens <= max_num_tokens * num_dp)), (
                    f"{expert_num_tokens} <= {max_num_tokens * num_dp}")

        for expert in range(num_local_experts):
            # Indexing expert_num_tokens doesn't work w/cudagraphs or inductor
            if (torch.compiler.is_compiling()
                    or torch.cuda.is_current_stream_capturing()):
                num = max_num_tokens * num_dp
            else:
                num = int(expert_num_tokens[expert].item())
            tmp = _resize_cache(workspace2, (num, N))
            input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1)
            self.activation(activation, tmp, input)
            out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1)

        return out

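# Editor's note, illustrative sketch (not part of the diff): the reference
# loop above projects each expert's tokens with w1 (output width 2 * N) and
# then applies a gated activation into an N-wide buffer before the w2
# projection. Assuming the "silu" activation follows the usual
# SiLU-and-multiply convention, the per-expert step looks roughly like this
# (names are illustrative; relies on this module's torch import).
def _example_expert_mlp(tokens: torch.Tensor, w1_e: torch.Tensor,
                        w2_e: torch.Tensor) -> torch.Tensor:
    gate_up = tokens @ w1_e.transpose(0, 1)  # [num, 2 * N]
    n = gate_up.size(-1) // 2
    hidden = torch.nn.functional.silu(gate_up[..., :n]) * gate_up[..., n:]
    return hidden @ w2_e.transpose(0, 1)  # [num, K]

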
class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
    """
    A Triton based MoE expert class that operates on expert batched format,
    i.e. E x max_num_tokens x K. This is the format that the pplx
    dispatch/combine kernels use.
    """

    def __init__(
        self,
        max_num_tokens: Optional[int] = None,
        use_fp8_w8a8: bool = False,
        use_int8_w8a8: bool = False,
        use_int8_w8a16: bool = False,
        use_int4_w4a16: bool = False,
        block_shape: Optional[list[int]] = None,
        world_size: int = 1,
        dp_size: int = 1,
    ):
        super().__init__()
        self.use_fp8_w8a8 = use_fp8_w8a8
        self.use_int8_w8a8 = use_int8_w8a8
        self.use_int4_w4a16 = use_int4_w4a16
        self.use_int8_w8a16 = use_int8_w8a16
        self.block_shape = block_shape
        self.max_num_tokens = max_num_tokens
        assert not use_int8_w8a8, "NYI"
        assert not use_int4_w4a16, "NYI"
        self.world_size = world_size
        self.dp_size = dp_size

    def workspace_shapes(
        self,
        a: torch.Tensor,
        M: int,
        N: int,
        K: int,
        topk: int,
        num_experts: int,
    ) -> tuple[int, int, torch.dtype]:
        assert a.dim() == 2
        num_dp = self.world_size // self.dp_size
        max_num_tokens = a.size(
            0) if self.max_num_tokens is None else self.max_num_tokens
        workspace13 = num_experts * max_num_tokens * num_dp * max(K, N)
        workspace2 = num_experts * max_num_tokens * num_dp * (N // 2)
        return (workspace13, workspace2, a.dtype)

    def apply(
        self,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int,
        expert_map: Optional[torch.Tensor],
        w1_scale: Optional[torch.Tensor],
        w2_scale: Optional[torch.Tensor],
        w1_zp: Optional[torch.Tensor],
        w2_zp: Optional[torch.Tensor],
        a1q_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor],
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_num_tokens: Optional[torch.Tensor],
    ) -> torch.Tensor:
        # Check constraints.
        if self.use_int4_w4a16:
            assert hidden_states.size(-1) // 2 == w1.size(2), (
                "Hidden size mismatch")
        else:
            assert hidden_states.size(-1) == w1.size(2), (
                f"Hidden size mismatch {hidden_states.size(-1)} "
                f"!= {w1.size(2)}")

        assert hidden_states.is_contiguous(
        ), "Hidden_states must be contiguous"
        assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
        assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
        assert hidden_states.dtype in [
            torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn
        ]

        # TODO: num_tokens -> max_num_tokens?
        E, num_tokens, N, K, top_k_num = mk._moe_problem_size(
            hidden_states, w1, w2, topk_ids)

        assert w1.size(0) == E
        assert w2.size(0) == E

        config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8,
                                            use_int8_w8a16=self.use_int8_w8a16,
                                            use_int4_w4a16=self.use_int4_w4a16,
                                            dtype=hidden_states.dtype)

        config = try_get_optimal_moe_config(
            w1.size(),
            w2.size(),
            top_k_num,
            config_dtype,
            num_tokens,
            block_shape=self.block_shape,
        )

        if hidden_states.dtype == torch.bfloat16:
            compute_type = tl.bfloat16
        elif hidden_states.dtype == torch.float16:
            compute_type = tl.float16
        elif hidden_states.dtype == torch.float32:
            compute_type = tl.float32
        elif hidden_states.dtype == torch.float8_e4m3fn:
            compute_type = tl.bfloat16
        else:
            raise ValueError(
                f"Unsupported compute_type: {hidden_states.dtype}")

        #print(f"shape: E={E}, M={num_tokens}, N={N}, K={K}, top_k={top_k_num}")
        # We can reuse the memory between these because by the time we need
        # cache3, we're done with cache1
        intermediate_cache1 = _resize_cache(workspace13, (E, num_tokens, N))
        intermediate_cache2 = _resize_cache(workspace2,
                                            (E, num_tokens, N // 2))
        intermediate_cache3 = _resize_cache(workspace13, (E, num_tokens, K))

        # MM1
        invoke_moe_batched_triton_kernel(A=hidden_states,
                                         B=w1,
                                         C=intermediate_cache1,
                                         expert_num_tokens=expert_num_tokens,
                                         compute_type=compute_type,
                                         A_scale=a1q_scale,
                                         B_scale=w1_scale,
                                         B_zp=w1_zp,
                                         use_fp8_w8a8=self.use_fp8_w8a8,
                                         use_int8_w8a16=self.use_int8_w8a16,
                                         use_int4_w4a16=self.use_int4_w4a16,
                                         config=config,
                                         block_shape=self.block_shape)

        # TODO: would be nice to use expert_num_tokens here to reduce
        # garbage compute
        self.activation(activation, intermediate_cache2.view(-1, N // 2),
                        intermediate_cache1.view(-1, N))

        #qintermediate_cache2 = intermediate_cache2
        a2q_scale = a2_scale
        # TODO (varun) : support w8a8
        assert not self.use_fp8_w8a8
        #if self.use_fp8_w8a8:
        #    qintermediate_cache2, a2q_scale = _fp8_quantize(
        #        intermediate_cache2, a2_scale, self.block_shape)

        invoke_moe_batched_triton_kernel(A=intermediate_cache2,
                                         B=w2,
                                         C=intermediate_cache3,
                                         expert_num_tokens=expert_num_tokens,
                                         compute_type=compute_type,
                                         A_scale=a2q_scale,
                                         B_scale=w2_scale,
                                         B_zp=w2_zp,
                                         use_fp8_w8a8=self.use_fp8_w8a8,
                                         use_int8_w8a16=self.use_int8_w8a16,
                                         use_int4_w4a16=self.use_int4_w4a16,
                                         config=config,
                                         block_shape=self.block_shape)

        return intermediate_cache3
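
# Editor's note, illustrative sketch (not part of the diff): intermediate_cache1
# and intermediate_cache3 above can alias the same workspace13 buffer because
# the first GEMM's output is fully consumed by the activation before the
# second GEMM writes cache3. A plausible _resize_cache-style helper (an
# assumption here, not the actual vLLM implementation) simply reinterprets a
# prefix of the flat workspace as the requested shape:
def _example_resize_cache(buf: torch.Tensor,
                          shape: tuple[int, ...]) -> torch.Tensor:
    numel = 1
    for dim in shape:
        numel *= dim
    # View the leading `numel` elements of the contiguous workspace.
    return buf.flatten()[:numel].view(*shape)
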
@@ -8,16 +8,17 @@ from typing import Any, Callable, Optional
 import torch

 import vllm.envs as envs
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
     _valid_deep_gemm, deep_gemm_moe_fp8)
 from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
     moe_align_block_size)
-from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    per_token_group_quant_fp8)
-from vllm.model_executor.layers.quantization.utils.int8_utils import (
-    per_token_group_quant_int8, per_token_quant_int8)
+from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+    MoEPrepareAndFinalizeNoEP)
+from vllm.model_executor.layers.fused_moe.utils import (
+    _resize_cache, moe_kernel_quantize_input)
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op
@@ -484,6 +485,20 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
     assert topk_weights is None or topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1

+    if use_fp8_w8a8 or use_int8_w8a8:
+        assert B_scale is not None
+        assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0])
+                == B_scale.shape[-2])
+        assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1])
+                == B_scale.shape[-1])
+
+    elif use_int8_w8a16 or use_int4_w4a16:
+        assert B_scale is not None
+        assert block_shape is None or block_shape[0] == 0
+    else:
+        assert A_scale is None
+        assert B_scale is None
+
     M = A.shape[0]
     num_tokens = M * top_k

@@ -855,6 +870,7 @@ def fused_topk(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
+    indices_type: Optional[torch.dtype] = None,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     assert hidden_states.shape[0] == gating_output.shape[0], (
         "Number of tokens mismatch")
@@ -865,10 +881,11 @@ def fused_topk(
                                topk,
                                dtype=torch.float32,
                                device=hidden_states.device)
-    topk_ids = torch.empty(M,
-                           topk,
-                           dtype=torch.int32,
-                           device=hidden_states.device)
+    topk_ids = torch.empty(
+        M,
+        topk,
+        dtype=torch.int32 if indices_type is None else indices_type,
+        device=hidden_states.device)
     token_expert_indices = torch.empty(M,
                                        topk,
                                        dtype=torch.int32,
@@ -962,6 +979,20 @@ def get_config_dtype_str(
     return None


+# TODO (bnell): use scalar_type instead of bools?
+def get_config_qtype(
+    use_fp8_w8a8: bool,
+    use_int8_w8a8: bool,
+    use_int8_w8a16: bool,
+    use_int4_w4a16: bool,
+) -> Optional[torch.dtype]:
+    if use_fp8_w8a8:
+        return torch.float8_e4m3fn
+    elif use_int8_w8a8:
+        return torch.int8
+    return None
+
+
 def inplace_fused_experts(hidden_states: torch.Tensor,
                           w1: torch.Tensor,
                           w2: torch.Tensor,
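
# Editor's note, quick illustration (not part of the diff): get_config_qtype
# added above maps the boolean quantization flags onto the activation
# quantization dtype that moe_kernel_quantize_input expects; the w8a16 and
# w4a16 schemes keep activations unquantized, so they map to None.
assert get_config_qtype(True, False, False, False) is torch.float8_e4m3fn
assert get_config_qtype(False, True, False, False) is torch.int8
assert get_config_qtype(False, False, True, False) is None
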
@@ -1128,7 +1159,10 @@ def fused_experts(hidden_states: torch.Tensor,
                   a2_scale: Optional[torch.Tensor] = None,
                   block_shape: Optional[list[int]] = None,
                   allow_deep_gemm: bool = False) -> torch.Tensor:
-    if (allow_deep_gemm and use_fp8_w8a8
+    # For now, disable DeepGemm for small N (<= 512) until better
+    # permute/unpermute ops are available.
+    N = w1.shape[1]
+    if (allow_deep_gemm and use_fp8_w8a8 and N > 512
             and _valid_deep_gemm(hidden_states, w1, w2, expert_map)):
         assert apply_router_weight_on_input is False
         return deep_gemm_moe_fp8(
@@ -1145,6 +1179,7 @@ def fused_experts(hidden_states: torch.Tensor,
             w2_scale=w2_scale,
             a1_scale=a1_scale,
             a2_scale=a2_scale,
+            apply_router_weight_on_input=apply_router_weight_on_input,
         )
     else:
         return dispatch_fused_experts_func(inplace)(
@@ -1171,87 +1206,37 @@ def fused_experts(hidden_states: torch.Tensor,
             block_shape=block_shape)


-def moe_kernel_prepare_input(
-    A: torch.Tensor,
-    B: torch.Tensor,
-    A_scale: Optional[torch.Tensor],
-    B_scale: Optional[torch.Tensor],
-    use_fp8_w8a8: bool,
-    use_int8_w8a8: bool,
-    use_int8_w8a16: bool,
-    use_int4_w4a16: bool,
-    per_channel_quant: bool,
-    block_shape: Optional[list[int]] = None,
-) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-    if use_fp8_w8a8:
-        assert B_scale is not None
-        if block_shape is None:
-            # If weights are per-channel (per_channel_quant=True), then
-            # activations apply per-token quantization. Otherwise, assume
-            # activation tensor-wise fp8 quantization, dynamic or static
-            A, A_scale = ops.scaled_fp8_quant(
-                A, A_scale, use_per_token_if_dynamic=per_channel_quant)
-        else:
-            # activation block-wise fp8 quantization
-            assert len(block_shape) == 2
-            _, block_k = block_shape[0], block_shape[1]
-            A, A_scale = per_token_group_quant_fp8(A, block_k)
-            assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
-            # assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
-            # assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
-    elif use_int8_w8a8:
-        assert B_scale is not None
-        if block_shape is None:
-            # activation channel-wise int8 quantization
-            assert (per_channel_quant
-                    ), "int8 quantization only supports block or channel-wise"
-            A, A_scale = per_token_quant_int8(A)
-        else:
-            # activation block-wise int8 quantization
-            assert len(block_shape) == 2
-            _, block_k = block_shape[0], block_shape[1]
-            A, A_scale = per_token_group_quant_int8(A, block_k)
-            assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
-            # assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
-            # assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
-    elif use_int8_w8a16 or use_int4_w4a16:
-        assert B_scale is not None
-        assert block_shape is None or block_shape[0] == 0
-    else:
-        assert A_scale is None
-        assert B_scale is None
-
-    return A, A_scale
-
-
-def fused_experts_impl(hidden_states: torch.Tensor,
-                       w1: torch.Tensor,
-                       w2: torch.Tensor,
-                       topk_weights: torch.Tensor,
-                       topk_ids: torch.Tensor,
-                       inplace: bool = False,
-                       activation: str = "silu",
-                       apply_router_weight_on_input: bool = False,
-                       use_fp8_w8a8: bool = False,
-                       use_int8_w8a8: bool = False,
-                       use_int8_w8a16: bool = False,
-                       use_int4_w4a16: bool = False,
-                       per_channel_quant: bool = False,
-                       global_num_experts: int = -1,
-                       expert_map: Optional[torch.Tensor] = None,
-                       w1_scale: Optional[torch.Tensor] = None,
-                       w2_scale: Optional[torch.Tensor] = None,
-                       w1_zp: Optional[torch.Tensor] = None,
-                       w2_zp: Optional[torch.Tensor] = None,
-                       a1_scale: Optional[torch.Tensor] = None,
-                       a2_scale: Optional[torch.Tensor] = None,
-                       block_shape: Optional[list[int]] = None):
+def fused_experts_impl(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    inplace: bool = False,
+    activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
+    use_fp8_w8a8: bool = False,
+    use_int8_w8a8: bool = False,
+    use_int8_w8a16: bool = False,
+    use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    w1_zp: Optional[torch.Tensor] = None,
+    w2_zp: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[list[int]] = None,
+) -> torch.Tensor:
     # Check constraints.
     if use_int4_w4a16:
         assert hidden_states.shape[1] // 2 == w1.shape[
             2], "Hidden size mismatch"
     else:
-        assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
+        assert hidden_states.shape[1] == w1.shape[2], (
+            f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}")

     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
@@ -1261,7 +1246,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
         torch.float32, torch.float16, torch.bfloat16
     ]

-    num_tokens, _ = hidden_states.shape
+    num_tokens = hidden_states.shape[0]
     E, N, _ = w1.shape
     K = w2.shape[1]
     if global_num_experts == -1:
@@ -1276,6 +1261,11 @@ def fused_experts_impl(hidden_states: torch.Tensor,
                                         use_int4_w4a16=use_int4_w4a16,
                                         dtype=hidden_states.dtype)

+    qtype = get_config_qtype(use_fp8_w8a8=use_fp8_w8a8,
+                             use_int8_w8a8=use_int8_w8a8,
+                             use_int8_w8a16=use_int8_w8a16,
+                             use_int4_w4a16=use_int4_w4a16)
+
     get_config_func = functools.partial(
         try_get_optimal_moe_config,
         w1.shape,
@@ -1338,15 +1328,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
         curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
         curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]

-        qcurr_hidden_states, qa1_scale = moe_kernel_prepare_input(
+        qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
             A=curr_hidden_states,
-            B=w1,
             A_scale=a1_scale,
-            B_scale=w1_scale,
-            use_fp8_w8a8=use_fp8_w8a8,
-            use_int8_w8a8=use_int8_w8a8,
-            use_int8_w8a16=use_int8_w8a16,
-            use_int4_w4a16=use_int4_w4a16,
+            qtype=qtype,
             per_channel_quant=per_channel_quant,
             block_shape=block_shape)

@@ -1357,7 +1342,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
         invoke_fused_moe_kernel(qcurr_hidden_states,
                                 w1,
                                 intermediate_cache1,
-                                qa1_scale,
+                                a1q_scale,
                                 w1_scale,
                                 w1_zp,
                                 curr_topk_weights,
@@ -1384,22 +1369,17 @@ def fused_experts_impl(hidden_states: torch.Tensor,
         else:
            raise ValueError(f"Unsupported FusedMoe activation: {activation}")

-        qintermediate_cache2, qa2_scale = moe_kernel_prepare_input(
+        qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
             A=intermediate_cache2,
-            B=w2,
             A_scale=a2_scale,
-            B_scale=w2_scale,
-            use_fp8_w8a8=use_fp8_w8a8,
-            use_int8_w8a8=use_int8_w8a8,
-            use_int8_w8a16=use_int8_w8a16,
-            use_int4_w4a16=use_int4_w4a16,
+            qtype=qtype,
             per_channel_quant=per_channel_quant,
             block_shape=block_shape)

         invoke_fused_moe_kernel(qintermediate_cache2,
                                 w2,
                                 intermediate_cache3,
-                                qa2_scale,
+                                a2q_scale,
                                 w2_scale,
                                 w2_zp,
                                 curr_topk_weights,
@@ -1534,3 +1514,209 @@ def fused_moe(
         a1_scale=a1_scale,
         a2_scale=a2_scale,
         block_shape=block_shape)
+
+
+class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
+
+    def __init__(
+        self,
+        use_fp8_w8a8: bool,
+        use_int8_w8a8: bool,
+        use_int8_w8a16: bool,
+        use_int4_w4a16: bool,
+        per_channel_quant: bool,
+        block_shape: Optional[list[int]] = None,
+        block_m: Optional[int] = None,
+    ):
+        super().__init__()
+        self.use_fp8_w8a8 = use_fp8_w8a8
+        self.use_int4_w4a16 = use_int4_w4a16
+        self.use_int8_w8a8 = use_int8_w8a8
+        self.use_int8_w8a16 = use_int8_w8a16
+        self.block_shape = block_shape
+        self.block_m = block_m
+        self.qtype = get_config_qtype(use_fp8_w8a8=use_fp8_w8a8,
+                                      use_int8_w8a8=use_int8_w8a8,
+                                      use_int8_w8a16=use_int8_w8a16,
+                                      use_int4_w4a16=use_int4_w4a16)
+        self.per_channel_quant = per_channel_quant
+
+    def workspace_shapes(
+        self,
+        a: torch.Tensor,
+        M: int,
+        N: int,
+        K: int,
+        topk: int,
+        num_experts: int,
+    ) -> tuple[int, int, torch.dtype]:
+        factor = num_experts if a.dim() == 3 else 1
+        workspace1 = M * topk * max(N * 2, K) * factor
+        workspace2 = M * topk * N * factor
+        return (workspace1, workspace2, a.dtype)
+
+    def apply(
+        self,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        w1_scale: Optional[torch.Tensor],
+        w2_scale: Optional[torch.Tensor],
+        w1_zp: Optional[torch.Tensor],
+        w2_zp: Optional[torch.Tensor],
+        a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        workspace13: torch.Tensor,
+        workspace2: torch.Tensor,
+        expert_num_tokens: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        # Check constraints.
+        if self.use_int4_w4a16:
+            assert hidden_states.size(-1) // 2 == w1.size(2), (
+                "Hidden size mismatch")
+        else:
+            assert hidden_states.size(-1) == w1.size(2), \
+                (f"Hidden size mismatch {hidden_states.size(-1)} "
+                 f"!= {w1.size(2)}")
+
+        assert hidden_states.is_contiguous(
+        ), "Hidden_states must be contiguous"
+        assert hidden_states.dim() == 2
+        assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
+        assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
+        assert hidden_states.dtype in [
+            torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn
+        ]
+
+        E, num_tokens, N, K, top_k_num = mk._moe_problem_size(
+            hidden_states, w1, w2, topk_ids)
+
+        if global_num_experts == -1:
+            global_num_experts = E
+
+        config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8,
+                                            use_int8_w8a16=self.use_int8_w8a16,
+                                            use_int4_w4a16=self.use_int4_w4a16,
+                                            dtype=hidden_states.dtype)
+
+        config = try_get_optimal_moe_config(
+            w1.shape,
+            w2.shape,
+            top_k_num,
+            config_dtype,
+            num_tokens,
+            block_shape=self.block_shape,
+        )
+
+        if hidden_states.dtype == torch.bfloat16:
+            compute_type = tl.bfloat16
+        elif hidden_states.dtype == torch.float16:
+            compute_type = tl.float16
+        elif hidden_states.dtype == torch.float32:
+            compute_type = tl.float32
+        elif hidden_states.dtype == torch.float8_e4m3fn:
+            compute_type = tl.bfloat16
+        else:
+            raise ValueError(
+                f"Unsupported compute_type: {hidden_states.dtype}")
+
+        # We can reuse the memory between these because by the time we need
+        # cache3, we're done with cache1
+        intermediate_cache1 = _resize_cache(workspace13,
+                                            (num_tokens, top_k_num, N))
+        intermediate_cache2 = _resize_cache(workspace2,
+                                            (num_tokens * top_k_num, N // 2))
+        intermediate_cache3 = _resize_cache(workspace13,
+                                            (num_tokens, top_k_num, K))
+
+        sorted_token_ids, expert_ids, num_tokens_post_padded = (
+            moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'],
+                                 global_num_experts, expert_map))
+
+        invoke_fused_moe_kernel(hidden_states,
+                                w1,
+                                intermediate_cache1,
+                                a1q_scale,
+                                w1_scale,
+                                w1_zp,
+                                None,
+                                sorted_token_ids,
+                                expert_ids,
+                                num_tokens_post_padded,
+                                False,
+                                top_k_num,
+                                config,
+                                compute_type=compute_type,
+                                use_fp8_w8a8=self.use_fp8_w8a8,
+                                use_int8_w8a8=self.use_int8_w8a8,
+                                use_int8_w8a16=self.use_int8_w8a16,
+                                use_int4_w4a16=self.use_int4_w4a16,
+                                per_channel_quant=self.per_channel_quant,
+                                block_shape=self.block_shape)
+
+        self.activation(activation, intermediate_cache2,
+                        intermediate_cache1.view(-1, N))
+
+        a2q_scale: Optional[torch.Tensor] = None
+
+        qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
+            intermediate_cache2, a2_scale, self.qtype, self.per_channel_quant,
+            self.block_shape)
+
+        invoke_fused_moe_kernel(qintermediate_cache2,
+                                w2,
+                                intermediate_cache3,
+                                a2q_scale,
+                                w2_scale,
+                                w2_zp,
+                                None,
+                                sorted_token_ids,
+                                expert_ids,
+                                num_tokens_post_padded,
+                                False,
+                                1,
+                                config,
+                                compute_type=compute_type,
+                                use_fp8_w8a8=self.use_fp8_w8a8,
+                                use_int8_w8a8=self.use_int8_w8a8,
+                                use_int8_w8a16=self.use_int8_w8a16,
+                                use_int4_w4a16=self.use_int4_w4a16,
+                                per_channel_quant=self.per_channel_quant,
+                                block_shape=self.block_shape)
+
+        return intermediate_cache3
+
+
+def modular_triton_fused_moe(
+    use_fp8_w8a8: bool,
+    use_int8_w8a8: bool,
+    use_int8_w8a16: bool,
+    use_int4_w4a16: bool,
+    per_channel_quant: bool,
+    block_shape: Optional[list[int]] = None,
+) -> mk.FusedMoEModularKernel:
+    qtype = get_config_qtype(
+        use_fp8_w8a8=use_fp8_w8a8,
+        use_int8_w8a8=use_int8_w8a8,
+        use_int8_w8a16=use_int8_w8a16,
+        use_int4_w4a16=use_int4_w4a16,
+    )
+    return mk.FusedMoEModularKernel(
+        MoEPrepareAndFinalizeNoEP(
+            quant_dtype=qtype,
+            per_channel_quant=per_channel_quant,
+            block_shape=block_shape,
+        ),
+        TritonExperts(
+            use_fp8_w8a8=use_fp8_w8a8,
+            use_int8_w8a8=use_int8_w8a8,
+            use_int8_w8a16=use_int8_w8a16,
+            use_int4_w4a16=use_int4_w4a16,
+            per_channel_quant=per_channel_quant,
+            block_shape=block_shape,
+        ),
+    )
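
# Editor's note, usage sketch (not part of the diff): modular_triton_fused_moe
# composes the no-EP prepare/finalize stage with TritonExperts into a single
# callable modular kernel. The construction below mirrors the factory above;
# how the returned kernel is invoked afterwards is only hinted at in the
# trailing comment and should be taken as an assumption.
_example_mk = modular_triton_fused_moe(use_fp8_w8a8=False,
                                       use_int8_w8a8=False,
                                       use_int8_w8a16=False,
                                       use_int4_w4a16=False,
                                       per_channel_quant=False)
# _example_mk(hidden_states, w1, w2, topk_weights, topk_ids, ...) then runs
# quantize -> permute -> expert GEMMs -> unpermute as one fused call.
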
@@ -1,15 +1,19 @@
 # SPDX-License-Identifier: Apache-2.0

+import importlib
+import threading
 from abc import abstractmethod
+from dataclasses import dataclass
 from enum import Enum
 from typing import Callable, Optional
+from weakref import WeakValueDictionary

 import torch
 import torch.nn.functional as F
 from torch.nn.parameter import UninitializedParameter

 import vllm.envs as envs
-from vllm.config import get_current_vllm_config
+from vllm.config import ParallelConfig, get_current_vllm_config
 from vllm.distributed import (get_dp_group, get_ep_group,
                               get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -26,8 +30,17 @@ from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
 from vllm.utils import direct_register_custom_op

+has_pplx = importlib.util.find_spec("pplx_kernels") is not None
+
 if current_platform.is_cuda_alike():
-    from .fused_moe import fused_experts
+    from .fused_batched_moe import (BatchedPrepareAndFinalize,
+                                    BatchedTritonExperts)
+    from .fused_moe import TritonExperts, fused_experts
+    from .modular_kernel import (FusedMoEModularKernel,
+                                 FusedMoEPermuteExpertsUnpermute,
+                                 FusedMoEPrepareAndFinalize)
+    if has_pplx:
+        from .pplx_prepare_finalize import PplxPrepareAndFinalize
 else:
     fused_experts = None  # type: ignore
 if is_rocm_aiter_moe_enabled():
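
# Editor's note, side note (not part of the diff): has_pplx above only probes
# whether the optional pplx_kernels package is installed; nothing is imported
# until the guarded import further down. The same pattern generalizes to any
# optional backend; "example_backend" below is a hypothetical package name.
import importlib.util

_has_example_backend = importlib.util.find_spec("example_backend") is not None
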
@@ -42,6 +55,179 @@ else:
     fused_moe_pallas = None  # type: ignore
 logger = init_logger(__name__)

+# Note: this limit is somewhat arbitrary and might be changed later.
+# The size of the activations will be E x MOE_DP_CHUNK_SIZE x hidden_dim.
+MOE_DP_CHUNK_SIZE = 256
+
+
+@dataclass
+class FusedMoEParallelConfig:
+    tp_size: int
+    dp_size: int
+    ep_size: int
+    tp_rank: int
+    dp_rank: int
+    ep_rank: int
+
+    use_ep: bool  # whether to use EP or not
+
+    @property
+    def use_pplx_kernels(self):
+        return self.dp_size > 1 and self.use_ep and has_pplx
+
+    @staticmethod
+    def make(tp_size_: int, dp_size_: int,
+             vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":
+        """
+        Determine the MoE parallel configuration. Based on the input tp_size_,
+        dp_size_, ep_size_ and vllm's parallel config, determine which levels
+        of parallelism to use in the fused moe layer.
+
+        Args:
+            tp_size_ (int): tp_size passed into the FusedMoE constructor.
+            dp_size_ (int): dp_size passed into the FusedMoE constructor.
+            ep_size_ (int): ep_size passed into the FusedMoE constructor.
+            vllm_parallel_config (ParallelConfig): vllm's parallel config
+                object.
+
+        Examples:
+            When there is no parallelism requested, i.e. tp_size_ = dp_size_ = 1,
+            we simply return the sizes unaltered and the ranks set to 0.
+
+            Expert Parallelism is considered only when either dp_size_ or
+            tp_size_ is non-trivial.
+
+            When TP = 2, DP = 1 and EP = False, the configuration on different
+            devices,
+                - device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} //
+                    legend : {size, rank}
+                - device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0}
+                - Comment : Tensors are sharded across 2 devices.
+
+            When TP = 1, DP = 2 and EP = False, the configuration on different
+            devices,
+                - device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0}
+                - device 1 : TP = {2, 1} DP = {2, 1} EP = {1, 0}
+                - Comment: There are 2 engine instances and the tensors are
+                    sharded across 2 devices.
+
+            When TP = 2, DP = 2 and EP = False, the configuration on different
+            devices,
+                - device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0}
+                - device 1: TP = {4, 1} DP = {2, 0} EP = {1, 0}
+                - device 2: TP = {4, 2} DP = {2, 1} EP = {1, 0}
+                - device 3: TP = {4, 3} DP = {2, 1} EP = {1, 0}
+                - Comment: There are 2 engine instances and the tensors are
+                    sharded across 4 devices.
+
+            When TP = 2, DP = 1 and EP = True, the configuration on different
+            devices,
+                - device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0}
+                - device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1}
+                - Comment: The experts are split between the 2 devices.
+
+            When TP = 1, DP = 2 and EP = True, the configuration on different
+            devices,
+                - device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0}
+                - device 1: TP = {1, 0} DP = {2, 1} EP = {2, 1}
+                - Comment: There are 2 engine instances and the experts are
+                    split between the 2 devices.
+
+            When TP = 2, DP = 2 and EP = True, the configuration on different
+            devices,
+                - device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0}
+                - device 1: TP = {1, 0} DP = {2, 0} EP = {4, 1}
+                - device 2: TP = {1, 0} DP = {2, 1} EP = {4, 2}
+                - device 3: TP = {1, 0} DP = {2, 1} EP = {4, 3}
+                - Comment: There are 2 engine instances and the experts are
+                    split between the 4 devices.
+        """
+
+        def flatten_tp_across_dp(dp_rank: int):
+            tp_rank = 0 if tp_size_ == 1 else get_tensor_model_parallel_rank()
+            # There are actually dp_size_ * tp_size_ devices. Update tp_size
+            # and tp_rank so we shard across all devices.
+            tp_size = dp_size_ * tp_size_
+            tp_rank = dp_rank * tp_size_ + tp_rank
+            return tp_size, tp_rank
+
+        use_ep = (dp_size_ * tp_size_ > 1
+                  and vllm_parallel_config.enable_expert_parallel)
+
+        dp_size = dp_size_
+        dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
+        tp_size, tp_rank = flatten_tp_across_dp(dp_rank)
+
+        if not use_ep:
+            return FusedMoEParallelConfig(tp_size=tp_size,
+                                          tp_rank=tp_rank,
+                                          dp_size=dp_size,
+                                          dp_rank=dp_rank,
+                                          ep_size=1,
+                                          ep_rank=0,
+                                          use_ep=False)
+        # DP + EP / TP + EP / DP + TP + EP
+        assert use_ep
+        # In EP, each device owns a set of experts fully. There is no tensor
+        # parallelism. Update tp_size, tp_rank, ep_size and ep_rank to
+        # reflect that.
+        ep_size = tp_size
+        ep_rank = tp_rank
+        return FusedMoEParallelConfig(tp_size=1,
+                                      tp_rank=0,
+                                      dp_size=dp_size,
+                                      dp_rank=dp_rank,
+                                      ep_size=ep_size,
+                                      ep_rank=ep_rank,
+                                      use_ep=True)
+
+
+# Adapted from pplx-kernels tests/all_to_all_utils.py
+@dataclass
+class MoEConfig:
+    num_experts: int
+    experts_per_token: int
+    hidden_dim: int
+
+    num_local_experts: int
+    moe_parallel_config: FusedMoEParallelConfig
+
+    in_dtype: torch.dtype  # The activation type.
+
+    # TODO: add more quantization params, blocked, per-token, etc.
+    block_size: int = 128
+
+    @property
+    def tp_size(self):
+        return self.moe_parallel_config.tp_size
+
+    @property
+    def dp_size(self):
+        return self.moe_parallel_config.dp_size
+
+    @property
+    def ep_size(self):
+        return self.moe_parallel_config.ep_size
+
+    @property
+    def tp_rank(self):
+        return self.moe_parallel_config.tp_rank
+
+    @property
+    def dp_rank(self):
+        return self.moe_parallel_config.dp_rank
+
+    @property
+    def ep_rank(self):
+        return self.moe_parallel_config.ep_rank
+
+    @property
+    def use_ep(self):
+        return self.moe_parallel_config.use_ep
+
+    @property
+    def use_pplx_kernels(self):
+        return self.moe_parallel_config.use_pplx_kernels
+
+
 class FusedMoeWeightScaleSupported(Enum):
     TENSOR = "tensor"
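
# Editor's note, worked example (not part of the diff), taken from the
# docstring above for TP = 2, DP = 2 with expert parallelism enabled:
# device 0 ends up with the flattened EP view below (tp collapsed to 1, ep
# covering all 4 devices).
_example_parallel_config = FusedMoEParallelConfig(tp_size=1,
                                                  tp_rank=0,
                                                  dp_size=2,
                                                  dp_rank=0,
                                                  ep_size=4,
                                                  ep_rank=0,
                                                  use_ep=True)
# PPLX kernels are only engaged when DP > 1, EP is on and pplx is installed.
assert _example_parallel_config.use_pplx_kernels == has_pplx
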
@@ -58,6 +244,14 @@ class FusedMoEMethodBase(QuantizeMethodBase):
                        params_dtype: torch.dtype, **extra_weight_attrs):
         raise NotImplementedError

+    def set_prepare_finalize(
+        self,
+        dp_size: int,
+        world_size: int,
+        prepare_finalize: FusedMoEPrepareAndFinalize,
+    ) -> bool:
+        return False
+
     @abstractmethod
     def apply(
         self,
@@ -80,12 +274,54 @@ class FusedMoEMethodBase(QuantizeMethodBase):
     raise NotImplementedError


+class AllToAllCache:
+
+    def __init__(self):
+        self._cache: WeakValueDictionary = WeakValueDictionary()
+        self._lock = threading.RLock()  # Reentrant lock for thread safety
+
+    def destroy(self):
+        with self._lock:
+            # TODO: can we do del self._cache?
+            for _, a2a in self._cache.items():
+                a2a.destroy()
+
+    def get_or_create(self, **kwargs):
+        assert has_pplx
+        import pplx_kernels as pplx
+
+        # Create a hashable key from the kwargs
+        key = tuple(sorted((k, v) for k, v in kwargs.items()))
+
+        with self._lock:
+            instance = self._cache.get(key)
+            if instance is None:
+                # TODO (varun): Add support to switch to intranode
+                # when all communications are within the same
+                # node.
+                logger.debug("Create AllToAll %s", kwargs)
+                instance = pplx.AllToAll.internode(**kwargs)
+                self._cache[key] = instance
+            return instance
+
+
+# Global singleton
+_all_to_all_cache = AllToAllCache()
+
+
+# Factory function as a cleaner interface
+def get_all_to_all(**kwargs):
+    return _all_to_all_cache.get_or_create(**kwargs)
+
+
 @CustomOp.register("unquantized_fused_moe")
 class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""

-    def __init__(self):
+    def __init__(self, moe: MoEConfig):
         super().__init__()
+        self.fused_experts = fused_experts
+        self.moe = moe
+
         self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
         if self.rocm_aiter_moe_enabled:
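
# Editor's note, illustrative sketch (not part of the diff): the cache key
# built in get_or_create above is just the kwargs flattened into a sorted
# tuple of (name, value) pairs, so two callers asking for the same AllToAll
# configuration share one kernel instance. All values must therefore be
# hashable (ints, dtypes, etc.).
_example_key_a = tuple(sorted({"rank": 0, "world_size": 4}.items()))
_example_key_b = tuple(sorted({"world_size": 4, "rank": 0}.items()))
assert _example_key_a == _example_key_b == (("rank", 0), ("world_size", 4))
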
@@ -193,6 +429,47 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             activation=activation,
             apply_router_weight_on_input=apply_router_weight_on_input)

+    def set_prepare_finalize(
+        self,
+        dp_size: int,
+        world_size: int,
+        prepare_finalize: FusedMoEPrepareAndFinalize,
+    ) -> bool:
+        assert self.fused_experts == fused_experts
+
+        experts: Optional[FusedMoEPermuteExpertsUnpermute] = None
+
+        if isinstance(prepare_finalize,
+                      (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
+            logger.debug("BatchedTritonExperts %s", self.moe)
+            experts = BatchedTritonExperts(
+                max_num_tokens=MOE_DP_CHUNK_SIZE,
+                world_size=world_size,
+                dp_size=dp_size,
+                use_fp8_w8a8=False,
+                use_int8_w8a8=False,
+                use_int8_w8a16=False,
+                use_int4_w4a16=False,
+                block_shape=None,
+            )
+        else:
+            logger.debug("TritonExperts %s", self.moe)
+            experts = TritonExperts(
+                use_fp8_w8a8=False,
+                use_int8_w8a8=False,
+                use_int8_w8a16=False,
+                use_int4_w4a16=False,
+                block_shape=None,
+                per_channel_quant=False,
+            )
+
+        self.fused_experts = FusedMoEModularKernel(
+            prepare_finalize,
+            experts,
+        )
+
+        return True
+
     def forward_cuda(
         self,
         layer: torch.nn.Module,
@@ -221,9 +498,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias)
+            e_score_correction_bias=e_score_correction_bias,
+            indices_type=torch.uint32 if self.moe.use_pplx_kernels else None)

         if self.rocm_aiter_moe_enabled:
+            assert not apply_router_weight_on_input
+            assert expert_map is None
             return self.rocm_aiter_fused_experts(
                 hidden_states=x,
                 w1=layer.w13_weight,
@@ -232,18 +512,19 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 topk_ids=topk_ids,
                 activation=activation,
                 apply_router_weight_on_input=apply_router_weight_on_input)
-        return fused_experts(
+        else:
+            return self.fused_experts(
                 hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
                 inplace=True,
                 activation=activation,
                 apply_router_weight_on_input=apply_router_weight_on_input,
                 global_num_experts=global_num_experts,
-                expert_map=expert_map)
+                expert_map=expert_map,
+            )

     def forward_cpu(
         self,
@@ -399,6 +680,45 @@ def determine_expert_map(
     return (local_num_experts, expert_map)


+def _construct_prepare_finalize(
+    moe: MoEConfig, quant_config: Optional[QuantizationConfig]
+) -> Optional[FusedMoEPrepareAndFinalize]:
+    max_num_tokens = MOE_DP_CHUNK_SIZE
+    world_size = moe.ep_size
+    dp_size = moe.ep_size // moe.dp_size  # dp_size actually means TP.
+    rank = moe.ep_rank
+
+    if moe.use_pplx_kernels:
+        logger.debug("using PplxPrepareAndFinalize")
+
+        all_to_all = get_all_to_all(
+            max_num_tokens=max_num_tokens,
+            num_experts=moe.num_experts,
+            experts_per_token=moe.experts_per_token,  # topk
+            rank=rank,
+            world_size=world_size,
+            dp_size=dp_size,
+            hidden_dim=moe.hidden_dim,
+            hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize,
+            # For blocked per token: set to
+            #   ceil_div(hidden_dim, block_size) * sizeof(float32)
+            # For per-token: set to sizeof(float32)
+            hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else
+                                    ((moe.hidden_dim + moe.block_size - 1) //
+                                     moe.block_size * torch.float32.itemsize)))
+
+        return PplxPrepareAndFinalize(
+            all_to_all,
+            max_num_tokens=max_num_tokens,
+            world_size=world_size,
+            rank=rank,
+            dp_size=dp_size,
+            quant_dtype=moe.in_dtype,
+        )
+
+    return None
+
+
 class FusedMoE(torch.nn.Module):
     """FusedMoE layer for MoE models.

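
# Editor's note, worked example (not part of the diff) for the
# hidden_dim_scale_bytes value computed above: with an fp8 activation dtype
# (itemsize == 1), hidden_dim = 4096 and block_size = 128, the per-token scale
# buffer needs ceil(4096 / 128) = 32 float32 scales, i.e. 128 bytes; for
# 16-bit activations the expression short-circuits to 0.
_hidden_dim, _block_size = 4096, 128
_scale_bytes = ((_hidden_dim + _block_size - 1) // _block_size *
                torch.float32.itemsize)
assert _scale_bytes == 32 * 4 == 128
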
@@ -449,21 +769,16 @@ class FusedMoE(torch.nn.Module):
             params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype

-        # Note: here we guard against accessing the TP and DP groups when
-        # uninitialized (this happens when testing)
-        self.tp_size = (tp_size if tp_size is not None else
-                        get_tensor_model_parallel_world_size())
-        tp_rank = 0 if self.tp_size == 1 else get_tensor_model_parallel_rank()
-        self.dp_size = (dp_size
-                        if dp_size is not None else get_dp_group().world_size)
-        self.dp_rank = (0
-                        if self.dp_size == 1 else get_dp_group().rank_in_group)
-        self.global_num_experts = num_experts
-
-        # Use expert parallelism instead of tensor parallelism?
         vllm_config = get_current_vllm_config()
-        use_ep = (vllm_config.parallel_config.enable_expert_parallel
-                  and self.tp_size * self.dp_size > 1)
+        self.moe_parallel_config: FusedMoEParallelConfig = (
+            FusedMoEParallelConfig.make(
+                tp_size_=(tp_size if tp_size is not None else
+                          get_tensor_model_parallel_world_size()),
+                dp_size_=(dp_size if dp_size is not None else
+                          get_dp_group().world_size),
+                vllm_parallel_config=vllm_config.parallel_config))
+
+        self.global_num_experts = num_experts

         # For smuggling this layer into the fused moe custom op
         self.use_direct_call = self.dp_size == 1
@@ -474,28 +789,17 @@ class FusedMoE(torch.nn.Module):
             compilation_config.static_forward_context[prefix] = self
         self.layer_name = prefix

-        if use_ep:
-            # Set TP size to 1 to adjust for EP and adjust EP size and rank
-            # for DP attention.
-            self.ep_rank = tp_rank + self.tp_size * self.dp_rank
-            self.tp_rank = 0
-            self.ep_size = self.tp_size * self.dp_size
-            self.tp_size = 1
-
+        # Determine expert maps
+        if self.use_ep:
             self.local_num_experts, self.expert_map = determine_expert_map(
                 ep_size=self.ep_size,
                 ep_rank=self.ep_rank,
                 global_num_experts=self.global_num_experts)
         else:
-            # Adjust TP size for DP attention
-            self.tp_rank = tp_rank + self.tp_size * self.dp_rank
-            self.ep_rank = 0
-            self.tp_size = self.tp_size * self.dp_size
-            self.ep_size = 1
-            self.local_num_experts = self.global_num_experts
-            self.expert_map = None
+            self.local_num_experts, self.expert_map = (self.global_num_experts,
+                                                       None)
         self.top_k = top_k
-        self.global_num_experts = num_experts

         assert intermediate_size % self.tp_size == 0
         self.hidden_size = hidden_size
@@ -520,14 +824,40 @@ class FusedMoE(torch.nn.Module):
             from vllm_hpu_extension.ops import DynamicFusedMOE
             self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)

+        moe = MoEConfig(
+            num_experts=self.global_num_experts,
+            experts_per_token=top_k,
+            hidden_dim=hidden_size,
+            num_local_experts=self.local_num_experts,
+            moe_parallel_config=self.moe_parallel_config,
+            # TODO (bnell): this needs to be fixed for quantized types.
+            in_dtype=params_dtype,
+        )
+
         # Note: get_quant_method will look at the layer's local_num_experts
         # for heuristic purposes, so it must be initialized first.
+        quant_method: Optional[QuantizeMethodBase] = None
+
         if quant_config is None:
-            self.quant_method: Optional[QuantizeMethodBase] = (
-                UnquantizedFusedMoEMethod())
+            quant_method = UnquantizedFusedMoEMethod(moe)
+            prepare_finalize = _construct_prepare_finalize(moe, quant_config)
         else:
-            self.quant_method = quant_config.get_quant_method(self, prefix)
-        assert self.quant_method is not None
+            quant_method = quant_config.get_quant_method(self, prefix)
+            # No pplx for quantized types yet.
+            prepare_finalize = None
+
+        assert quant_method is not None
+        assert isinstance(quant_method, FusedMoEMethodBase)
+        self.quant_method = quant_method
+
+        if prepare_finalize is not None:
+            world_size = moe.ep_size
+            dp_size = int(moe.ep_size // moe.dp_size)
+            success = self.quant_method.set_prepare_finalize(
+                dp_size, world_size, prepare_finalize)
+            if not success:
+                logger.warning("DP+EP not supported for %s.",
+                               type(self.quant_method))
+
         moe_quant_params = {
             "num_experts": self.local_num_experts,
@@ -546,6 +876,38 @@ class FusedMoE(torch.nn.Module):

         self.quant_method.create_weights(layer=self, **moe_quant_params)
+
+    @property
+    def tp_size(self):
+        return self.moe_parallel_config.tp_size
+
+    @property
+    def dp_size(self):
+        return self.moe_parallel_config.dp_size
+
+    @property
+    def ep_size(self):
+        return self.moe_parallel_config.ep_size
+
+    @property
+    def tp_rank(self):
+        return self.moe_parallel_config.tp_rank
+
+    @property
+    def dp_rank(self):
+        return self.moe_parallel_config.dp_rank
+
+    @property
+    def ep_rank(self):
+        return self.moe_parallel_config.ep_rank
+
+    @property
+    def use_ep(self):
+        return self.moe_parallel_config.use_ep
|
||||||
|
|
||||||
|
@property
|
||||||
|
def use_pplx_kernels(self):
|
||||||
|
return self.moe_parallel_config.use_pplx_kernels
|
||||||
|
|
||||||
def _load_per_tensor_weight_scale(self, shard_id: str,
|
def _load_per_tensor_weight_scale(self, shard_id: str,
|
||||||
param: torch.nn.Parameter,
|
param: torch.nn.Parameter,
|
||||||
loaded_weight: torch.Tensor,
|
loaded_weight: torch.Tensor,
|
||||||
@ -830,7 +1192,8 @@ class FusedMoE(torch.nn.Module):
|
|||||||
num_expert_group: Optional[int] = None,
|
num_expert_group: Optional[int] = None,
|
||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
scoring_func: str = "softmax",
|
scoring_func: str = "softmax",
|
||||||
e_score_correction_bias: Optional[torch.Tensor] = None):
|
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||||
|
indices_type: Optional[torch.dtype] = None):
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||||
|
|
||||||
# DeepSeekV2 uses grouped_top_k
@ -846,21 +1209,52 @@ class FusedMoE(torch.nn.Module):
|
|||||||
topk_group=topk_group,
|
topk_group=topk_group,
|
||||||
scoring_func=scoring_func,
|
scoring_func=scoring_func,
|
||||||
e_score_correction_bias=e_score_correction_bias)
|
e_score_correction_bias=e_score_correction_bias)
|
||||||
|
if indices_type is not None:
|
||||||
|
topk_ids = topk_ids.to(dtype=indices_type)
|
||||||
elif custom_routing_function is None:
|
elif custom_routing_function is None:
|
||||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
gating_output=router_logits,
|
gating_output=router_logits,
|
||||||
topk=top_k,
|
topk=top_k,
|
||||||
renormalize=renormalize)
|
renormalize=renormalize,
|
||||||
|
indices_type=indices_type,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
topk_weights, topk_ids = custom_routing_function(
|
topk_weights, topk_ids = custom_routing_function(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
gating_output=router_logits,
|
gating_output=router_logits,
|
||||||
topk=top_k,
|
topk=top_k,
|
||||||
renormalize=renormalize)
|
renormalize=renormalize)
|
||||||
|
if indices_type is not None:
|
||||||
|
topk_ids = topk_ids.to(dtype=indices_type)
|
||||||
|
|
||||||
return topk_weights, topk_ids
|
return topk_weights, topk_ids
|
||||||
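Editor's illustration (not part of the diff): the new indices_type argument lets callers request a specific dtype for the returned expert ids; a pplx-style dispatch, for instance, wants unsigned 32-bit indices. A hedged sketch of a call, with hypothetical input tensors:

import torch

# Hypothetical inputs; in practice these come from the model's router.
hidden_states = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)
router_logits = torch.randn(16, 64, device="cuda", dtype=torch.float32)

topk_weights, topk_ids = FusedMoE.select_experts(
    hidden_states=hidden_states,
    router_logits=router_logits,
    top_k=2,
    use_grouped_topk=False,
    renormalize=True,
    indices_type=torch.uint32,  # e.g. what a pplx-style dispatch expects
)
assert topk_ids.dtype == torch.uint32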
|
|
||||||
|
def must_reduce_shared_expert_outputs(self) -> bool:
|
||||||
|
"""
|
||||||
|
The shared_experts are typically computed using the RowParallelLinear
|
||||||
|
layer. The result of this function is typically used as
|
||||||
|
the reduce_results argument to the module.
|
||||||
|
When just tensor-parallel is used, it is not required to reduce
the shared_experts results immediately. Instead we reduce once at
the end of the MoE op. (Refer to the DeepSeekV2MoE module.)
With EP and the pplx kernels this is no longer viable, since all
GPU ranks in the DP group produce the complete set of hidden_states.
|
||||||
|
Therefore it is required that we reduce the shared_experts output
|
||||||
|
early.
|
||||||
|
"""
|
||||||
|
return self.use_pplx_kernels
|
||||||
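Illustrative usage (editor's sketch, not part of the diff; the surrounding attribute names are hypothetical): a model would typically feed this into the shared-expert down projection so its reduction strategy matches the fused MoE path.

from vllm.model_executor.layers.linear import RowParallelLinear

self.shared_expert_down_proj = RowParallelLinear(
    intermediate_size,
    hidden_size,
    bias=False,
    # Reduce immediately only when the fused MoE path (e.g. pplx) requires it.
    reduce_results=self.experts.must_reduce_shared_expert_outputs(),
)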
|
|
||||||
|
def maybe_all_reduce_tensor_model_parallel(
|
||||||
|
self, final_hidden_states: torch.Tensor):
|
||||||
|
"""
|
||||||
|
The pplx combine kernel reduces across GPU ranks by default.
|
||||||
|
"""
|
||||||
|
if self.use_pplx_kernels:
|
||||||
|
return final_hidden_states
|
||||||
|
else:
|
||||||
|
return tensor_model_parallel_all_reduce(final_hidden_states)
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor,
|
def forward(self, hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor):
|
router_logits: torch.Tensor):
|
||||||
if self.use_direct_call:
|
if self.use_direct_call:
|
||||||
@ -869,9 +1263,62 @@ class FusedMoE(torch.nn.Module):
|
|||||||
return torch.ops.vllm.moe_forward(hidden_states, router_logits,
|
return torch.ops.vllm.moe_forward(hidden_states, router_logits,
|
||||||
self.layer_name)
|
self.layer_name)
|
||||||
|
|
||||||
|
def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
|
||||||
|
full_router_logits: torch.Tensor):
|
||||||
|
|
||||||
|
full_final_hidden_states = torch.empty_like(full_hidden_states)
|
||||||
|
|
||||||
|
def process_chunk(chunk_start, chunk_end, skip_result_store=False):
|
||||||
|
hidden_states = full_hidden_states[chunk_start:chunk_end, :]
|
||||||
|
router_logits = full_router_logits[chunk_start:chunk_end, :]
|
||||||
|
|
||||||
|
# Matrix multiply.
|
||||||
|
final_hidden_states = self.quant_method.apply(
|
||||||
|
layer=self,
|
||||||
|
x=hidden_states,
|
||||||
|
router_logits=router_logits,
|
||||||
|
top_k=self.top_k,
|
||||||
|
renormalize=self.renormalize,
|
||||||
|
use_grouped_topk=self.use_grouped_topk,
|
||||||
|
global_num_experts=self.global_num_experts,
|
||||||
|
expert_map=self.expert_map,
|
||||||
|
topk_group=self.topk_group,
|
||||||
|
num_expert_group=self.num_expert_group,
|
||||||
|
custom_routing_function=self.custom_routing_function,
|
||||||
|
scoring_func=self.scoring_func,
|
||||||
|
e_score_correction_bias=self.e_score_correction_bias,
|
||||||
|
activation=self.activation,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not skip_result_store:
|
||||||
|
full_final_hidden_states[chunk_start:chunk_end, :].copy_(
|
||||||
|
final_hidden_states)
|
||||||
|
|
||||||
|
ctx = get_forward_context()
|
||||||
|
max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
|
||||||
|
moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE
|
||||||
|
|
||||||
|
num_tokens = full_hidden_states.size(0)
|
||||||
|
for chunk_start_ in range(0, max_tokens_across_dp,
|
||||||
|
moe_dp_chunk_size_per_rank):
|
||||||
|
chunk_start = chunk_start_
|
||||||
|
chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank,
|
||||||
|
max_tokens_across_dp)
|
||||||
|
# clamp start and end
|
||||||
|
chunk_start = min(chunk_start, num_tokens - 1)
|
||||||
|
chunk_end = min(chunk_end, num_tokens)
|
||||||
|
|
||||||
|
process_chunk(chunk_start,
|
||||||
|
chunk_end,
|
||||||
|
skip_result_store=chunk_start_ >= num_tokens)
|
||||||
|
|
||||||
|
return full_final_hidden_states
|
||||||
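To make the clamping above concrete, here is an editor's illustration of the loop bounds only (chunk size and token counts are made up): every DP rank runs the same number of iterations, and iterations past this rank's own tokens re-process the last token with skip_result_store=True.

MOE_DP_CHUNK_SIZE = 4        # assumed chunk size for this example
num_tokens = 3               # tokens on this rank
max_tokens_across_dp = 10    # max tokens on any rank in the DP group

for chunk_start_ in range(0, max_tokens_across_dp, MOE_DP_CHUNK_SIZE):
    chunk_end = min(chunk_start_ + MOE_DP_CHUNK_SIZE, max_tokens_across_dp)
    chunk_start = min(chunk_start_, num_tokens - 1)
    chunk_end = min(chunk_end, num_tokens)
    print(chunk_start, chunk_end, chunk_start_ >= num_tokens)
# prints: 0 3 False / 2 3 True / 2 3 True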
|
|
||||||
def forward_impl(self, hidden_states: torch.Tensor,
|
def forward_impl(self, hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor):
|
router_logits: torch.Tensor):
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
|
if self.moe_parallel_config.use_pplx_kernels:
|
||||||
|
return self.forward_impl_chunked(hidden_states, router_logits)
|
||||||
|
|
||||||
if self.dp_size > 1:
|
if self.dp_size > 1:
|
||||||
hidden_states, router_logits = get_ep_group().dispatch(
|
hidden_states, router_logits = get_ep_group().dispatch(
|
||||||
|
vllm/model_executor/layers/fused_moe/modular_kernel.py (new file)
@ -0,0 +1,364 @@
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
#
# This file defines a set of base classes used to make MoE kernels more modular.
# The goal is to be able to utilize different communication mechanisms with
# any fused MoE kernel without needing to have combinatoric implementations.
#
# The fused moe kernels are broken down into the following components:
#
# [Router] → [Quantize-Dispatch] → [Permute-Experts-Unpermute] → [Combine]
#
# Each component will be independent of the others except for
# [Quantize-Dispatch] and [Combine] (see below). The components can then be
# mixed and matched so that DP+EP can be supported easily for multiple
# MoE kernel implementations.
#
# The following main classes are defined:
# * FusedMoEPrepareAndFinalize - an abstract base class for preparation of MoE
# inputs (e.g. quantization, distribution) and finalization of MoE outputs.
# The prepare method must take care of any needed quantization and the
# finalize method must apply weights and do the final reduction of the output.
# * FusedMoEPermuteExpertsUnpermute - an abstract base class for the main fused
# MoE operation. One important feature to note is that this class does not
# apply topk weights or reduce the final output.
# * FusedMoEModularKernel - an interface class that combines a
# FusedMoEPrepareAndFinalize and a FusedMoEPermuteExpertsUnpermute to
# provide the standard fused MoE kernel interface.
#
# [Quantize-Prepare] and [Finalize] functionality are bundled into a single
# class `FusedMoEPrepareAndFinalize` since they could use collective
# communication mechanisms that need to be consistent.
#

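As an editor's illustration of how these pieces compose (not part of the new file; the no-arg constructor defaults for TritonExperts and MoEPrepareAndFinalizeNoEP are assumed to exist as introduced in this PR):

import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP)

# Compose [Quantize-Dispatch]+[Combine] with [Permute-Experts-Unpermute].
kernel = mk.FusedMoEModularKernel(
    prepare_finalize=MoEPrepareAndFinalizeNoEP(),
    fused_experts=TritonExperts(),
)

E, I, H, M, topk = 8, 256, 1024, 4, 2
a = torch.randn(M, H, device="cuda", dtype=torch.bfloat16)
w1 = torch.randn(E, 2 * I, H, device="cuda", dtype=torch.bfloat16)
w2 = torch.randn(E, H, I, device="cuda", dtype=torch.bfloat16)
topk_weights = torch.rand(M, topk, device="cuda", dtype=torch.float32)
topk_ids = torch.randint(0, E, (M, topk), device="cuda", dtype=torch.int32)

out = kernel(a, w1, w2, topk_weights, topk_ids)  # shape (M, H)
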
def _moe_problem_size(
|
||||||
|
a1: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
) -> tuple[int, int, int, int, int]:
|
||||||
|
"""
|
||||||
|
Extract the MoE problem size from the given tensor arguments:
|
||||||
|
- a: The hidden states, input to the MoE layer.
|
||||||
|
- w1: The first set of expert weights.
|
||||||
|
- w2: The second set of expert weights.
|
||||||
|
- topk_ids: The topk ids.
|
||||||
|
|
||||||
|
Note: extracting the problem shape from the weight and activation tensors is
|
||||||
|
not obvious. It needs to be done this way specifically due to subtle issues
|
||||||
|
with particular kernels, e.g. the int4 kernels divide the trailing dimension
|
||||||
|
by two, so it's not "correct" to extract N or K from the trailing dimension
|
||||||
|
of w1 or w2. Similarly, some kernels transpose the weights, so this needs
|
||||||
|
to be kept in mind.
|
||||||
|
"""
|
||||||
|
assert w1.dim() == 3 and w2.dim() == 3
|
||||||
|
E, N, _ = w1.size()
|
||||||
|
K = w2.size(1)
|
||||||
|
|
||||||
|
if a1.dim() == 2:
|
||||||
|
# Make sure we are using the correct a1 (pre-permute).
|
||||||
|
assert topk_ids.size(0) == a1.size(0), \
|
||||||
|
f"{topk_ids.size(0)} != {a1.size(0)}"
|
||||||
|
M = a1.size(0)
|
||||||
|
else:
|
||||||
|
assert a1.dim() == 3
|
||||||
|
assert a1.size(0) == E, f"{a1.size(0)} == {E}"
|
||||||
|
M = a1.size(1) # This is max_num_tokens
|
||||||
|
|
||||||
|
assert topk_ids.dim() == 2
|
||||||
|
topk = topk_ids.size(1)
|
||||||
|
|
||||||
|
return E, M, N, K, topk
|
||||||
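A concrete example of the extracted sizes (editor's illustration, assuming _moe_problem_size is in scope as defined above) for the standard, non-batched activation format:

import torch

E, I, H, M, topk = 8, 256, 128, 4, 2   # experts, intermediate, hidden, tokens, topk
a1 = torch.empty(M, H)                 # hidden states, so K == H
w1 = torch.empty(E, 2 * I, H)          # fused gate/up projection, so N == 2 * I
w2 = torch.empty(E, H, I)              # down projection
topk_ids = torch.empty(M, topk, dtype=torch.int32)

assert _moe_problem_size(a1, w1, w2, topk_ids) == (E, M, 2 * I, H, topk)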
|
|
||||||
|
|
||||||
|
class FusedMoEPrepareAndFinalize(ABC):
|
||||||
|
"""
|
||||||
|
An abstract base class for the [Quantize-Prepare] and [Finalize] steps
|
||||||
|
described above.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def prepare(
|
||||||
|
self,
|
||||||
|
a1: torch.Tensor,
|
||||||
|
a1_scale: Optional[torch.Tensor],
|
||||||
|
a2_scale: Optional[torch.Tensor],
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
num_experts: int,
|
||||||
|
expert_map: Optional[torch.Tensor],
|
||||||
|
apply_router_weight_on_input: bool,
|
||||||
|
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||||
|
"""
|
||||||
|
Perform any quantization and/or dispatching needed for this kernel.
|
||||||
|
- a1: The (unquantized) input to the MoE layer.
|
||||||
|
- a1_scale: Optional scales for a1
|
||||||
|
- a2_scale: Optional scales for the second MoE gemm. Required to make
|
||||||
|
sure the quantization is consistent for both gemms.
|
||||||
|
- topk_ids: The topk ids.
|
||||||
|
- topk_weights: The topk weights.
|
||||||
|
- num_experts: The total number of experts in the global expert space.
|
||||||
|
- expert_map: A tensor mapping expert indices from the global expert
|
||||||
|
space to the local expert space of the expert parallel shard.
|
||||||
|
- apply_router_weight_on_input: When True, apply the weights to the
|
||||||
|
activations, before quantization + dispatching.
|
||||||
|
|
||||||
|
Returns a tuple of:
|
||||||
|
- quantized + dispatched a.
|
||||||
|
- quantized + dispatched a1_scales.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def finalize(
|
||||||
|
self,
|
||||||
|
output: torch.Tensor,
|
||||||
|
fused_expert_output: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
apply_router_weight_on_input: bool,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Perform the combine step: apply the topk weights to the fused experts
output and reduce it into the final output tensor.
|
||||||
|
- output: The output tensor, written in place. Must be (M, K) shape.
|
||||||
|
- fused_expert_output: The unweighted, unreduced output of the fused
|
||||||
|
experts, it will have (M, topk, K) shape.
|
||||||
|
- topk_weights: The weights to be applied to the fused_experts_output.
|
||||||
|
- topk_ids: The topk_ids.
|
||||||
|
- apply_router_weight_on_input: When False, apply the weights to
|
||||||
|
fused_expert_output.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||||
|
"""
|
||||||
|
An abstract base class for the [Permute-Experts-Unpermute] step described
|
||||||
|
above.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def workspace_shapes(
|
||||||
|
self,
|
||||||
|
a: torch.Tensor,
|
||||||
|
M: int,
|
||||||
|
N: int,
|
||||||
|
K: int,
|
||||||
|
topk: int,
|
||||||
|
num_experts: int,
|
||||||
|
) -> tuple[int, int, torch.dtype]:
|
||||||
|
"""
|
||||||
|
Compute the number of elements for the temporary outputs of the two
|
||||||
|
gemms and activation in the fused expert function. Since the
|
||||||
|
gemms are independent, the workspace for the first gemm can be shared
|
||||||
|
with the workspace for the last gemm.
|
||||||
|
|
||||||
|
Returns a tuple of:
|
||||||
|
- Number of workspace13 elements: must be large enough to hold the
|
||||||
|
result of either expert gemm.
|
||||||
|
- Number of workspace2 elements: must be large enough to hold the
|
||||||
|
result of the activation function.
|
||||||
|
- Workspace type: The dtype to use for the workspace tensors.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def activation(self, activation: str, output: torch.Tensor,
|
||||||
|
input: torch.Tensor) -> None:
|
||||||
|
assert output.size(-1) * 2 == input.size(-1)
|
||||||
|
if activation == "silu":
|
||||||
|
torch.ops._C.silu_and_mul(output, input)
|
||||||
|
elif activation == "gelu":
|
||||||
|
torch.ops._C.gelu_and_mul(output, input)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
activation: str,
|
||||||
|
global_num_experts: int,
|
||||||
|
expert_map: Optional[torch.Tensor],
|
||||||
|
w1_scale: Optional[torch.Tensor],
|
||||||
|
w2_scale: Optional[torch.Tensor],
|
||||||
|
w1_zp: Optional[torch.Tensor],
|
||||||
|
w2_zp: Optional[torch.Tensor],
|
||||||
|
a1q_scale: Optional[torch.Tensor],
|
||||||
|
a2_scale: Optional[torch.Tensor],
|
||||||
|
workspace13: torch.Tensor,
|
||||||
|
workspace2: torch.Tensor,
|
||||||
|
expert_num_tokens: Optional[torch.Tensor],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
This function computes the intermediate result of a Mixture of Experts
|
||||||
|
(MoE) layer using two sets of weights, w1 and w2.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- hidden_states: (torch.Tensor): The (quantized) input tensor to the MoE
|
||||||
|
layer.
|
||||||
|
- w1 (torch.Tensor): The first set of expert weights.
|
||||||
|
- w2 (torch.Tensor): The second set of expert weights.
|
||||||
|
- topk_ids (torch.Tensor): A map of row to expert id.
|
||||||
|
- activation (str): The activation function to apply after the first
|
||||||
|
MoE layer.
|
||||||
|
- global_num_experts (int): The total number of experts in the global
|
||||||
|
expert space.
|
||||||
|
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
|
||||||
|
from the global expert space to the local expert space of the expert
|
||||||
|
parallel shard.
|
||||||
|
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
|
||||||
|
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
|
||||||
|
- w1_zp (Optional[torch.Tensor]): Optional zero points to be used for
|
||||||
|
w1.
|
||||||
|
- w2_zp (Optional[torch.Tensor]): Optional zero points to be used for
|
||||||
|
w2.
|
||||||
|
- a1q_scale (Optional[torch.Tensor]): Optional quantized scale to be
|
||||||
|
used for a1.
|
||||||
|
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
|
||||||
|
- workspace13 (torch.Tensor): A scratch tensor used for gemm outputs
|
||||||
|
must be large enough to hold output of either MoE gemm.
|
||||||
|
- workspace2 (torch.Tensor): A scratch tensor used for the activation
|
||||||
|
function.
|
||||||
|
- expert_num_tokens: An optional tensor containing the number of tokens
|
||||||
|
assigned to each expert when using batched experts format input.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- torch.Tensor: The unweighted, unreduced output tensor
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class FusedMoEModularKernel(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
This class combines a FusedMoEPrepareAndFinalize instance and
|
||||||
|
a FusedMoEPermuteExpertsUnpermute to provide an interface that
|
||||||
|
is compatible with the `fused_experts` function in fused_moe.py.
|
||||||
|
|
||||||
|
It takes care of managing any required scratch space.
|
||||||
|
|
||||||
|
Note: Instances of this class should only be used for a single model
|
||||||
|
layer due to any layer specific state that may be used by the component
|
||||||
|
objects.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||||
|
fused_experts: FusedMoEPermuteExpertsUnpermute,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.prepare_finalize = prepare_finalize
|
||||||
|
self.fused_experts = fused_experts
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
inplace: bool = False,
|
||||||
|
activation: str = "silu",
|
||||||
|
global_num_experts: int = -1,
|
||||||
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
|
w1_scale: Optional[torch.Tensor] = None,
|
||||||
|
w2_scale: Optional[torch.Tensor] = None,
|
||||||
|
w1_zp: Optional[torch.Tensor] = None,
|
||||||
|
w2_zp: Optional[torch.Tensor] = None,
|
||||||
|
a1_scale: Optional[torch.Tensor] = None,
|
||||||
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
This function computes a Mixture of Experts (MoE) layer using two sets
|
||||||
|
of weights, w1 and w2, and top-k gating mechanism.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- hidden_states: (torch.Tensor): The input tensor to the MoE layer.
|
||||||
|
- w1 (torch.Tensor): The first set of expert weights.
|
||||||
|
- w2 (torch.Tensor): The second set of expert weights.
|
||||||
|
- topk_weights (torch.Tensor): The topk weights applied at the end of
|
||||||
|
the layer.
|
||||||
|
- topk_ids (torch.Tensor): A map of row to expert id.
|
||||||
|
- inplace (bool): If True, perform the operation in-place.
|
||||||
|
Defaults to False.
|
||||||
|
- activation (str): The activation function to apply after the first
|
||||||
|
MoE layer.
|
||||||
|
- global_num_experts (int): The total number of experts in the global
|
||||||
|
expert space.
|
||||||
|
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
|
||||||
|
from the global expert space to the local expert space of the expert
|
||||||
|
parallel shard.
|
||||||
|
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
|
||||||
|
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
|
||||||
|
- w1_zp (Optional[torch.Tensor]): Optional zero points to be used for
|
||||||
|
w1.
|
||||||
|
- w2_zp (Optional[torch.Tensor]): Optional zero points to be used for
|
||||||
|
w2.
|
||||||
|
- a1_scale (Optional[torch.Tensor]): Optional scale to be used for a1.
|
||||||
|
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
|
||||||
|
- apply_router_weight_on_input (bool): When true, the topk weights are
|
||||||
|
applied directly on the inputs. This is only applicable when topk is
|
||||||
|
1.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- torch.Tensor: The output tensor after applying the MoE layer.
|
||||||
|
"""
|
||||||
|
a1 = hidden_states
|
||||||
|
E, M, N, K, top_k = _moe_problem_size(a1, w1, w2, topk_ids)
|
||||||
|
|
||||||
|
if global_num_experts == -1:
|
||||||
|
global_num_experts = E
|
||||||
|
|
||||||
|
output = a1 if inplace else torch.zeros_like(a1)
|
||||||
|
|
||||||
|
workspace13_shape, workspace2_shape, workspace_dtype = (
|
||||||
|
self.fused_experts.workspace_shapes(a1, M, N, K, top_k,
|
||||||
|
global_num_experts))
|
||||||
|
|
||||||
|
# We can reuse the memory between cache1 and cache3 because by the time
|
||||||
|
# we need cache3, we're done with cache1
|
||||||
|
workspace13 = torch.zeros(workspace13_shape,
|
||||||
|
device=a1.device,
|
||||||
|
dtype=workspace_dtype)
|
||||||
|
workspace2 = torch.zeros(workspace2_shape,
|
||||||
|
device=a1.device,
|
||||||
|
dtype=workspace_dtype)
|
||||||
|
|
||||||
|
a1q, a1q_scale, expert_num_tokens = self.prepare_finalize.prepare(
|
||||||
|
a1, a1_scale, a2_scale, topk_weights, topk_ids, global_num_experts,
|
||||||
|
expert_map, apply_router_weight_on_input)
|
||||||
|
|
||||||
|
fused_out = self.fused_experts.apply(
|
||||||
|
a1q,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
topk_ids,
|
||||||
|
activation=activation,
|
||||||
|
global_num_experts=global_num_experts,
|
||||||
|
expert_map=expert_map,
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
w1_zp=w1_zp,
|
||||||
|
w2_zp=w2_zp,
|
||||||
|
a1q_scale=a1q_scale,
|
||||||
|
a2_scale=a2_scale,
|
||||||
|
workspace13=workspace13,
|
||||||
|
workspace2=workspace2,
|
||||||
|
expert_num_tokens=expert_num_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.prepare_finalize.finalize(output, fused_out, topk_weights,
|
||||||
|
topk_ids, apply_router_weight_on_input)
|
||||||
|
|
||||||
|
return output
|
||||||
@ -3,6 +3,74 @@ from typing import Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
||||||
|
moe_align_block_size)
|
||||||
|
from vllm.model_executor.layers.fused_moe.utils import _fp8_perm
|
||||||
|
|
||||||
|
|
||||||
|
def _moe_permute(
|
||||||
|
curr_hidden_states: torch.Tensor,
|
||||||
|
a1q_scale: Optional[torch.Tensor],
|
||||||
|
curr_topk_ids: torch.Tensor,
|
||||||
|
global_num_experts: int,
|
||||||
|
expert_map: Optional[torch.Tensor],
|
||||||
|
block_m: int,
|
||||||
|
) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
|
||||||
|
Optional[torch.Tensor]]:
|
||||||
|
"""
|
||||||
|
Determine the sorted_token_ids, expert_ids for the given problem size.
|
||||||
|
Permute the hidden states and scales according to `sorted_token_ids`.
|
||||||
|
"""
|
||||||
|
top_k_num = curr_topk_ids.size(1)
|
||||||
|
|
||||||
|
tokens_in_chunk = curr_hidden_states.size(0)
|
||||||
|
|
||||||
|
sorted_token_ids, expert_ids, num_tokens_post_padded = (
|
||||||
|
moe_align_block_size(curr_topk_ids,
|
||||||
|
block_m,
|
||||||
|
global_num_experts,
|
||||||
|
expert_map,
|
||||||
|
pad_sorted_ids=True))
|
||||||
|
|
||||||
|
inv_perm: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
num_tokens = top_k_num * tokens_in_chunk
|
||||||
|
sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
|
||||||
|
expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)
|
||||||
|
inv_perm = torch.argsort(sorted_token_ids)[:num_tokens]
|
||||||
|
|
||||||
|
# Permute according to sorted token ids.
|
||||||
|
curr_hidden_states = _fp8_perm(curr_hidden_states,
|
||||||
|
sorted_token_ids // top_k_num)
|
||||||
|
|
||||||
|
if a1q_scale is not None:
|
||||||
|
a1q_scale = a1q_scale[sorted_token_ids // top_k_num]
|
||||||
|
|
||||||
|
return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
|
||||||
|
inv_perm)
|
||||||
|
|
||||||
|
|
||||||
|
def _moe_unpermute_and_reduce(
|
||||||
|
out: torch.Tensor,
|
||||||
|
curr_hidden: torch.Tensor,
|
||||||
|
inv_perm: Optional[torch.Tensor],
|
||||||
|
topk_weight: torch.Tensor,
|
||||||
|
apply_router_weight_on_input: bool,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Unpermute the final result and apply topk_weights, then perform the final
|
||||||
|
reduction on the hidden states.
|
||||||
|
"""
|
||||||
|
M, topk = topk_weight.size()
|
||||||
|
K = curr_hidden.size(-1)
|
||||||
|
if inv_perm is not None:
|
||||||
|
curr_hidden = curr_hidden[inv_perm, ...]
|
||||||
|
curr_hidden = curr_hidden.view(-1, topk, K)
|
||||||
|
if not apply_router_weight_on_input:
|
||||||
|
curr_hidden.mul_(topk_weight.view(M, -1, 1))
|
||||||
|
ops.moe_sum(curr_hidden, out)
|
||||||
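In plain torch, the weight-and-reduce step above is roughly equivalent to the following (editor's sketch; ops.moe_sum performs the sum over the topk dimension):

import torch

M, topk, K = 4, 2, 8
curr_hidden = torch.randn(M * topk, K)   # already unpermuted expert outputs
topk_weight = torch.rand(M, topk)
out = torch.empty(M, K)

weighted = curr_hidden.view(M, topk, K) * topk_weight.view(M, topk, 1)
out.copy_(weighted.sum(dim=1))           # same result as ops.moe_sum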
|
|
||||||
|
|
||||||
def moe_permute(
|
def moe_permute(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@ -17,21 +85,21 @@ def moe_permute(
|
|||||||
fill_invalid_expert: int = -1
|
fill_invalid_expert: int = -1
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
This function expands and permutes the activations to gather non-contiguous
tokens for each expert.
Parameters:
|
Parameters:
|
||||||
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
|
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
|
||||||
- topk_weights (torch.Tensor): topk expert route weight for each token.
|
- topk_weights (torch.Tensor): topk expert route weight for each token.
|
||||||
- topk_ids (torch.Tensor): topk expert route id for each token.
|
- topk_ids (torch.Tensor): topk expert route id for each token.
|
||||||
- token_expert_indices (torch.Tensor): indices for the expanded hidden states.
- topk (int): The number of top-k experts to select.
|
- topk (int): The number of top-k experts to select.
|
||||||
- n_expert (int): The number of experts.
- n_local_expert (int): The number of experts on the current EP rank.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
|
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
|
||||||
from the global expert space to the local expert space of the expert
|
from the global expert space to the local expert space of the expert
|
||||||
parallel shard.
|
parallel shard.
|
||||||
- align_block_size (Optional[int]): align group gemm block size for deepgemm
|
- align_block_size (Optional[int]): align group gemm block size for deepgemm
|
||||||
- fill_invalid_expert (int): expert id written into m_indices for invalid
  experts, to work around DeepGEMM not supporting -1 in m_indices
Returns:
|
Returns:
|
||||||
- permuted_hidden_states (torch.Tensor): permuted activation.
|
- permuted_hidden_states (torch.Tensor): permuted activation.
|
||||||
@ -39,10 +107,10 @@ def moe_permute(
|
|||||||
of each expert for standard grouped gemm. If 'align_block_size' is enabled,
expert_first_token_offset will be aligned up to 'align_block_size'.
- src_row_id2dst_row_id_map (torch.Tensor): idx map for moe_unpermute.
|
- src_row_id2dst_row_id_map (torch.Tensor): idx map for moe_unpermute.
|
||||||
- m_indices: m_indices for grouped gemm in DeepGEMM; `m_indices[i]` records
  the group that the i-th row of the LHS belongs to.
"""
|
"""
|
||||||
n_token, n_hidden = hidden_states.shape
|
n_token, n_hidden = hidden_states.size()
|
||||||
assert (n_hidden * hidden_states.element_size()
|
assert (n_hidden * hidden_states.element_size()
|
||||||
) % 16 == 0, "permue kernel need hidden dim align to 16B"
|
) % 16 == 0, "permue kernel need hidden dim align to 16B"
|
||||||
permuted_row_size = n_token * topk
|
permuted_row_size = n_token * topk
|
||||||
@ -87,7 +155,7 @@ def moe_unpermute(
|
|||||||
n_local_expert: int,
|
n_local_expert: int,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
This function unpermutes and reduces the expert outputs, scattering them back
to the original token order.
Parameters:
|
Parameters:
|
||||||
- permuted_hidden_states (torch.Tensor): permuted activation.
|
- permuted_hidden_states (torch.Tensor): permuted activation.
|
||||||
@ -99,10 +167,10 @@ def moe_unpermute(
|
|||||||
- n_expert (int): The number of expert.
|
- n_expert (int): The number of expert.
|
||||||
- n_local_expert (int): The number of expert in current EP rank.
|
- n_local_expert (int): The number of expert in current EP rank.
|
||||||
Returns:
|
Returns:
|
||||||
- hidden_states (torch.Tensor): The reduced and unpermuted activation
|
- hidden_states (torch.Tensor): The reduced and unpermuted activation
|
||||||
tensor.
|
tensor.
|
||||||
"""
|
"""
|
||||||
n_token, n_hidden = topk_weights.shape[0], permuted_hidden_states.shape[-1]
|
n_token, n_hidden = topk_weights.size(0), permuted_hidden_states.size(-1)
|
||||||
assert (n_hidden * permuted_hidden_states.element_size()
|
assert (n_hidden * permuted_hidden_states.element_size()
|
||||||
) % 16 == 0, "unpermue kernel need hidden dim align to 16B"
|
) % 16 == 0, "unpermue kernel need hidden dim align to 16B"
|
||||||
hidden_states = torch.empty((n_token, n_hidden),
|
hidden_states = torch.empty((n_token, n_hidden),
|
||||||
|
vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py (new file)
@ -0,0 +1,147 @@
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import pplx_kernels as pplx
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
|
from vllm.model_executor.layers.fused_moe.utils import (
|
||||||
|
moe_kernel_quantize_input)
|
||||||
|
|
||||||
|
|
||||||
|
# Note use: layer.get_all_to_all() to get an AllToAll instance
|
||||||
|
# The max_num_tokens, world_size and dp_size must be the same
|
||||||
|
# as the ones used to create the AllToAll.
|
||||||
|
class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
a2a: pplx.AllToAll,
|
||||||
|
max_num_tokens: int,
|
||||||
|
world_size: int,
|
||||||
|
rank: int,
|
||||||
|
dp_size: int,
|
||||||
|
quant_dtype: Optional[torch.dtype] = None,
|
||||||
|
block_shape: Optional[list[int]] = None):
|
||||||
|
super().__init__()
|
||||||
|
assert max_num_tokens > 0
|
||||||
|
self.a2a = a2a
|
||||||
|
self.block_shape = block_shape
|
||||||
|
self.max_num_tokens = max_num_tokens
|
||||||
|
self.world_size = world_size
|
||||||
|
self.rank = rank
|
||||||
|
self.dp_size = dp_size
|
||||||
|
self.quant_dtype = quant_dtype
|
||||||
|
|
||||||
|
def prepare(
|
||||||
|
self,
|
||||||
|
a1: torch.Tensor,
|
||||||
|
a1_scale: Optional[torch.Tensor],
|
||||||
|
a2_scale: Optional[torch.Tensor],
|
||||||
|
rank_topk_weights: torch.Tensor,
|
||||||
|
rank_topk_ids: torch.Tensor,
|
||||||
|
num_experts: int,
|
||||||
|
expert_map: Optional[torch.Tensor],
|
||||||
|
apply_router_weight_on_input: bool,
|
||||||
|
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||||
|
num_tokens = a1.size(0) # M
|
||||||
|
hidden_dim = a1.size(-1) # K
|
||||||
|
|
||||||
|
assert rank_topk_ids.size(0) == num_tokens
|
||||||
|
# assert expert_map is None, "NYI"
|
||||||
|
|
||||||
|
# Is this always going to be a1.device?
|
||||||
|
device = a1.device
|
||||||
|
|
||||||
|
if apply_router_weight_on_input:
|
||||||
|
topk = rank_topk_ids.size(1)
|
||||||
|
# TODO: this only works for topK=1, will need to update for topK>1
|
||||||
|
assert topk == 1, (
|
||||||
|
"apply_router_weight_on_input is only implemented for topk=1")
|
||||||
|
a1 = a1 * rank_topk_weights.to(a1.dtype)
|
||||||
|
|
||||||
|
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
|
||||||
|
a2_scale.numel() != 1 if a2_scale is not None else False)
|
||||||
|
|
||||||
|
a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale,
|
||||||
|
self.quant_dtype,
|
||||||
|
per_act_token,
|
||||||
|
self.block_shape)
|
||||||
|
|
||||||
|
# rem_experts need to be 0 for pplx to work properly.
|
||||||
|
rem_experts = num_experts % self.world_size
|
||||||
|
assert rem_experts == 0
|
||||||
|
num_local_experts = ((num_experts // self.world_size) +
|
||||||
|
(1 if self.rank < rem_experts else 0))
|
||||||
|
|
||||||
|
expert_num_tokens = torch.empty(
|
||||||
|
num_local_experts,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
num_dp = self.world_size // self.dp_size
|
||||||
|
expert_x = torch.empty(
|
||||||
|
(num_local_experts, self.max_num_tokens * num_dp, hidden_dim),
|
||||||
|
dtype=a1q.dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
expert_x_scale: Optional[torch.Tensor] = None
|
||||||
|
if a1q.dtype.itemsize == 1:
|
||||||
|
float32_size = torch.float32.itemsize
|
||||||
|
block_size = (self.block_shape[0] if self.block_shape is not None
|
||||||
|
else 1) * float32_size
|
||||||
|
expert_x_scale = torch.empty(
|
||||||
|
(
|
||||||
|
num_experts,
|
||||||
|
expert_x.size(1),
|
||||||
|
(expert_x.size(2) + block_size - 1) // block_size,
|
||||||
|
),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
# This argument is optional, defaults to indices.size(0)
|
||||||
|
# There's not much point setting this unless it is != indices.size(0)
|
||||||
|
bound_m: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
self.a2a.dispatch(
|
||||||
|
out_expert_num_tokens=expert_num_tokens,
|
||||||
|
out_expert_x=expert_x,
|
||||||
|
out_expert_x_scale=expert_x_scale,
|
||||||
|
dp_x=a1q,
|
||||||
|
dp_x_scale=a1q_scale,
|
||||||
|
indices=rank_topk_ids,
|
||||||
|
bound_m=bound_m,
|
||||||
|
)
|
||||||
|
|
||||||
|
return expert_x, expert_x_scale, expert_num_tokens
|
||||||
|
|
||||||
|
def finalize(
|
||||||
|
self,
|
||||||
|
output: torch.Tensor,
|
||||||
|
fused_expert_output: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
apply_router_weight_on_input: bool,
|
||||||
|
) -> None:
|
||||||
|
num_tokens = output.size(0) # M
|
||||||
|
# This argument is optional
|
||||||
|
# There's not much point setting this unless it is != topk_ids.size(0)
|
||||||
|
bound_m: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
assert topk_ids.size(0) == num_tokens, (
|
||||||
|
f"{topk_ids.size(0)} == {num_tokens}")
|
||||||
|
assert output.size(0) <= self.max_num_tokens, (
|
||||||
|
f"{output.size(0)} <= {self.max_num_tokens}")
|
||||||
|
assert output.size(1) == fused_expert_output.size(-1)
|
||||||
|
|
||||||
|
# Set weights to 1 if we did them in dispatch. This is hacky.
|
||||||
|
if apply_router_weight_on_input:
|
||||||
|
topk_weights = torch.ones_like(topk_weights)
|
||||||
|
|
||||||
|
self.a2a.combine(out_tokens=output,
|
||||||
|
indices=topk_ids,
|
||||||
|
weights=topk_weights,
|
||||||
|
expert_y=fused_expert_output,
|
||||||
|
bound_m=bound_m)
|
||||||
vllm/model_executor/layers/fused_moe/prepare_finalize.py (new file)
@ -0,0 +1,60 @@
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
|
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
|
||||||
|
_moe_unpermute_and_reduce)
|
||||||
|
from vllm.model_executor.layers.fused_moe.utils import (
|
||||||
|
moe_kernel_quantize_input)
|
||||||
|
|
||||||
|
|
||||||
|
class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
quant_dtype: Optional[torch.dtype] = None,
|
||||||
|
per_channel_quant: bool = False,
|
||||||
|
block_shape: Optional[list[int]] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.per_channel_quant = per_channel_quant
|
||||||
|
self.block_shape = block_shape
|
||||||
|
self.quant_dtype = quant_dtype
|
||||||
|
|
||||||
|
def prepare(
|
||||||
|
self,
|
||||||
|
a1: torch.Tensor,
|
||||||
|
a1_scale: Optional[torch.Tensor],
|
||||||
|
a2_scale: Optional[torch.Tensor],
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
num_experts: int,
|
||||||
|
expert_map: Optional[torch.Tensor],
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
|
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||||
|
if apply_router_weight_on_input:
|
||||||
|
topk = topk_ids.size(1)
|
||||||
|
# TODO: this only works for topK=1, will need to update for topK>1
|
||||||
|
assert topk == 1, \
|
||||||
|
"apply_router_weight_on_input is only implemented for topk=1"
|
||||||
|
a1.mul_(topk_weights.to(a1.dtype))
|
||||||
|
|
||||||
|
a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale,
|
||||||
|
self.quant_dtype,
|
||||||
|
self.per_channel_quant,
|
||||||
|
self.block_shape)
|
||||||
|
|
||||||
|
return a1q, a1q_scale, None
|
||||||
|
|
||||||
|
def finalize(
|
||||||
|
self,
|
||||||
|
output: torch.Tensor,
|
||||||
|
fused_expert_output: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
apply_router_weight_on_input: bool,
|
||||||
|
) -> None:
|
||||||
|
_moe_unpermute_and_reduce(output, fused_expert_output, None,
|
||||||
|
topk_weights, apply_router_weight_on_input)
|
||||||
vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py (new file)
@ -0,0 +1,112 @@
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
|
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||||
|
DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape)
|
||||||
|
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
|
||||||
|
|
||||||
|
|
||||||
|
class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
use_fp8_w8a8: bool = False,
|
||||||
|
use_int8_w8a8: bool = False,
|
||||||
|
use_int8_w8a16: bool = False,
|
||||||
|
use_int4_w4a16: bool = False,
|
||||||
|
per_channel_quant: bool = False,
|
||||||
|
block_shape: Optional[list[int]] = None,
|
||||||
|
block_m: Optional[int] = None,
|
||||||
|
allow_deep_gemm: bool = False):
|
||||||
|
super().__init__()
|
||||||
|
self.triton_expert = TritonExperts(use_fp8_w8a8=use_fp8_w8a8,
|
||||||
|
use_int8_w8a8=use_int8_w8a8,
|
||||||
|
use_int4_w4a16=use_int4_w4a16,
|
||||||
|
use_int8_w8a16=use_int8_w8a16,
|
||||||
|
per_channel_quant=per_channel_quant,
|
||||||
|
block_shape=block_shape,
|
||||||
|
block_m=block_m)
|
||||||
|
self.deep_gemm_expert = DeepGemmExperts()
|
||||||
|
self.allow_deep_gemm = allow_deep_gemm
|
||||||
|
self.use_fp8_w8a8 = use_fp8_w8a8
|
||||||
|
|
||||||
|
def workspace_shapes(
|
||||||
|
self,
|
||||||
|
a: torch.Tensor,
|
||||||
|
M: int,
|
||||||
|
N: int,
|
||||||
|
K: int,
|
||||||
|
topk: int,
|
||||||
|
num_experts: int,
|
||||||
|
) -> tuple[int, int, torch.dtype]:
|
||||||
|
# Note: the deep gemm workspaces are strictly larger than the triton
|
||||||
|
# workspaces so we can be pessimistic here and allocate for DeepGemm
|
||||||
|
# even if we fall back to triton later, e.g. if expert maps are set.
|
||||||
|
if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K):
|
||||||
|
return self.deep_gemm_expert.workspace_shapes(
|
||||||
|
a, M, N, K, topk, num_experts)
|
||||||
|
else:
|
||||||
|
return self.triton_expert.workspace_shapes(a, M, N, K, topk,
|
||||||
|
num_experts)
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
activation: str,
|
||||||
|
global_num_experts: int,
|
||||||
|
expert_map: Optional[torch.Tensor],
|
||||||
|
w1_scale: Optional[torch.Tensor],
|
||||||
|
w2_scale: Optional[torch.Tensor],
|
||||||
|
w1_zp: Optional[torch.Tensor],
|
||||||
|
w2_zp: Optional[torch.Tensor],
|
||||||
|
a1q_scale: Optional[torch.Tensor],
|
||||||
|
a2_scale: Optional[torch.Tensor],
|
||||||
|
workspace13: torch.Tensor,
|
||||||
|
workspace2: torch.Tensor,
|
||||||
|
expert_num_tokens: Optional[torch.Tensor],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
N = w1.size(1)
|
||||||
|
if (self.allow_deep_gemm and self.use_fp8_w8a8 and N > 512
|
||||||
|
and _valid_deep_gemm(hidden_states, w1, w2, expert_map)):
|
||||||
|
return self.deep_gemm_expert.apply(
|
||||||
|
hidden_states,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
topk_ids,
|
||||||
|
activation,
|
||||||
|
global_num_experts,
|
||||||
|
expert_map,
|
||||||
|
w1_scale,
|
||||||
|
w2_scale,
|
||||||
|
w1_zp,
|
||||||
|
w2_zp,
|
||||||
|
a1q_scale,
|
||||||
|
a2_scale,
|
||||||
|
workspace13,
|
||||||
|
workspace2,
|
||||||
|
expert_num_tokens,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return self.triton_expert.apply(
|
||||||
|
hidden_states,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
topk_ids,
|
||||||
|
activation,
|
||||||
|
global_num_experts,
|
||||||
|
expert_map,
|
||||||
|
w1_scale,
|
||||||
|
w2_scale,
|
||||||
|
w1_zp,
|
||||||
|
w2_zp,
|
||||||
|
a1q_scale,
|
||||||
|
a2_scale,
|
||||||
|
workspace13,
|
||||||
|
workspace2,
|
||||||
|
expert_num_tokens,
|
||||||
|
)
|
||||||
@ -7,6 +7,8 @@ import torch
|
|||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
per_token_group_quant_fp8)
|
per_token_group_quant_fp8)
|
||||||
|
from vllm.model_executor.layers.quantization.utils.int8_utils import (
|
||||||
|
per_token_group_quant_int8, per_token_quant_int8)
|
||||||
from vllm.utils import cdiv
|
from vllm.utils import cdiv
|
||||||
|
|
||||||
|
|
||||||
@ -15,34 +17,81 @@ def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor:
|
|||||||
Shrink the given tensor and apply the given view to it. This is
|
Shrink the given tensor and apply the given view to it. This is
|
||||||
used to resize the intermediate fused_moe caches.
|
used to resize the intermediate fused_moe caches.
|
||||||
"""
|
"""
|
||||||
assert prod(v) <= x.numel()
|
assert prod(
|
||||||
|
v) <= x.numel(), f"{prod(v)} <= {x.numel()}" # CUDAGRAPH unfriendly?
|
||||||
return x.flatten()[:prod(v)].view(*v)
|
return x.flatten()[:prod(v)].view(*v)
|
||||||
|
|
||||||
|
|
||||||
def _fp8_quantize(
|
def _fp8_quantize(
|
||||||
A: torch.Tensor,
|
A: torch.Tensor,
|
||||||
A_scale: Optional[torch.Tensor],
|
A_scale: Optional[torch.Tensor],
|
||||||
block_shape: Optional[list[int]],
|
per_act_token: bool,
|
||||||
|
block_shape: Optional[list[int]] = None,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Perform fp8 quantization on the inputs. If a block_shape
|
Perform fp8 quantization on the inputs. If a block_shape
|
||||||
is provided, the output will be blocked.
|
is provided, the output will be blocked.
|
||||||
"""
|
"""
|
||||||
if block_shape is None:
|
if block_shape is None:
|
||||||
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
|
A, A_scale = ops.scaled_fp8_quant(
|
||||||
|
A, A_scale, use_per_token_if_dynamic=per_act_token)
|
||||||
else:
|
else:
|
||||||
assert len(block_shape) == 2
|
assert len(block_shape) == 2
|
||||||
_, block_k = block_shape[0], block_shape[1]
|
_, block_k = block_shape[0], block_shape[1]
|
||||||
A, A_scale = per_token_group_quant_fp8(A, block_k)
|
A, A_scale = per_token_group_quant_fp8(A, block_k)
|
||||||
assert cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
|
assert cdiv(A.size(-1), block_k) == A_scale.size(-1)
|
||||||
|
|
||||||
return A, A_scale
|
return A, A_scale
|
||||||
|
|
||||||
|
|
||||||
|
def _int8_quantize(
|
||||||
|
A: torch.Tensor,
|
||||||
|
A_scale: Optional[torch.Tensor],
|
||||||
|
per_act_token: bool,
|
||||||
|
block_shape: Optional[list[int]] = None,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Perform int8 quantization on the inputs. If a block_shape
|
||||||
|
is provided, the output will be blocked.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# If weights are per-channel (per_channel_quant=True), then
|
||||||
|
# activations apply per-token quantization. Otherwise, assume
|
||||||
|
# activation tensor-wise fp8/int8 quantization, dynamic or static
|
||||||
|
if block_shape is None:
|
||||||
|
assert per_act_token, \
|
||||||
|
"int8 quantization only supports block or channel-wise"
|
||||||
|
A, A_scale = per_token_quant_int8(A)
|
||||||
|
else:
|
||||||
|
assert len(block_shape) == 2
|
||||||
|
_, block_k = block_shape[0], block_shape[1]
|
||||||
|
A, A_scale = per_token_group_quant_int8(A, block_k)
|
||||||
|
assert cdiv(A.size(-1), block_k) == A_scale.size(-1)
|
||||||
|
|
||||||
|
return A, A_scale
|
||||||
|
|
||||||
|
|
||||||
|
def moe_kernel_quantize_input(
|
||||||
|
A: torch.Tensor,
|
||||||
|
A_scale: Optional[torch.Tensor],
|
||||||
|
qtype: Optional[torch.dtype],
|
||||||
|
per_channel_quant: bool,
|
||||||
|
block_shape: Optional[list[int]] = None,
|
||||||
|
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
if qtype == torch.float8_e4m3fn:
|
||||||
|
return _fp8_quantize(A, A_scale, per_channel_quant, block_shape)
|
||||||
|
elif qtype == torch.int8:
|
||||||
|
return _int8_quantize(A, A_scale, per_channel_quant, block_shape)
|
||||||
|
else:
|
||||||
|
assert A_scale is None
|
||||||
|
return A, A_scale
|
||||||
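For example, dynamically quantizing bf16 activations for a block-quantized fp8 kernel could look like this (editor's sketch, CUDA assumed):

import torch
from vllm.model_executor.layers.fused_moe.utils import (
    moe_kernel_quantize_input)

A = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)

A_q, A_scale = moe_kernel_quantize_input(
    A,
    A_scale=None,                 # dynamic quantization
    qtype=torch.float8_e4m3fn,
    per_channel_quant=False,
    block_shape=[128, 128],       # group size along the hidden dim
)
assert A_q.dtype == torch.float8_e4m3fn
assert A_scale.shape == (16, 1024 // 128)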
|
|
||||||
|
|
||||||
def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
|
def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
A permutation routine that works on fp8 types.
|
A permutation routine that works on fp8 types.
|
||||||
"""
|
"""
|
||||||
if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8:
|
if torch.is_floating_point(m) and m.dtype.itemsize == 1:
|
||||||
return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
|
return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
|
||||||
else:
|
else:
|
||||||
return m[idx, ...]
|
return m[idx, ...]
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import functools
|
||||||
import importlib.util
|
import importlib.util
|
||||||
from typing import Any, Callable, Optional
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
@ -9,6 +10,7 @@ from torch.nn import Module
|
|||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@ -434,6 +436,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, quant_config: Fp8Config):
|
def __init__(self, quant_config: Fp8Config):
|
||||||
|
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.block_quant = self.quant_config.weight_block_size is not None
|
self.block_quant = self.quant_config.weight_block_size is not None
|
||||||
|
|
||||||
@ -458,6 +461,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
"DeepGemm not supported on the current platform.")
|
"DeepGemm not supported on the current platform.")
|
||||||
|
|
||||||
|
self.fused_experts = functools.partial(
|
||||||
|
fused_experts,
|
||||||
|
block_shape=self.quant_config.weight_block_size,
|
||||||
|
allow_deep_gemm=self.allow_deep_gemm)
|
||||||
|
|
||||||
def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
|
def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
|
||||||
intermediate_size_per_partition: int,
|
intermediate_size_per_partition: int,
|
||||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||||
@ -783,6 +791,31 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
del layer.w13_input_scale
|
del layer.w13_input_scale
|
||||||
del layer.w2_input_scale
|
del layer.w2_input_scale
|
||||||
|
|
||||||
|
def set_prepare_finalize(
|
||||||
|
self,
|
||||||
|
dp_size: int,
|
||||||
|
world_size: int,
|
||||||
|
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||||
|
) -> bool:
|
||||||
|
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
||||||
|
TritonOrDeepGemmExperts)
|
||||||
|
|
||||||
|
if self.use_marlin or self.rocm_aiter_moe_enabled:
|
||||||
|
return False
|
||||||
|
|
||||||
|
experts = TritonOrDeepGemmExperts(
|
||||||
|
use_fp8_w8a8=True,
|
||||||
|
block_shape=self.quant_config.weight_block_size,
|
||||||
|
allow_deep_gemm=self.allow_deep_gemm,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.fused_experts = mk.FusedMoEModularKernel(
|
||||||
|
prepare_finalize,
|
||||||
|
experts,
|
||||||
|
)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
@@ -801,10 +834,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
     ) -> torch.Tensor:
-        from vllm.model_executor.layers.fused_moe import fused_experts
-        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-            rocm_aiter_fused_experts)
-
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -819,6 +848,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         )
 
         if self.rocm_aiter_moe_enabled:
+            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
+                rocm_aiter_fused_experts)
             return rocm_aiter_fused_experts(
                 x,
                 layer.w13_weight,
@@ -835,8 +866,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 a1_scale=layer.w13_input_scale,
                 a2_scale=layer.w2_input_scale,
                 block_shape=self.quant_config.weight_block_size)
-
-        if self.use_marlin:
+        elif self.use_marlin:
             assert activation == "silu", (
                 f"{activation} not supported for Marlin MoE.")
             assert not apply_router_weight_on_input, (
@@ -853,28 +883,26 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 quant_type_id=scalar_types.float8_e4m3fn.id,
                 global_num_experts=global_num_experts,
                 expert_map=expert_map)
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
+        else:
+            return self.fused_experts(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
                 inplace=True,
                 activation=activation,
                 use_fp8_w8a8=True,
                 global_num_experts=global_num_experts,
                 apply_router_weight_on_input=apply_router_weight_on_input,
                 expert_map=expert_map,
                 w1_scale=(layer.w13_weight_scale_inv
                           if self.block_quant else layer.w13_weight_scale),
                 w2_scale=(layer.w2_weight_scale_inv
                           if self.block_quant else layer.w2_weight_scale),
                 a1_scale=layer.w13_input_scale,
                 a2_scale=layer.w2_input_scale,
-            block_shape=self.quant_config.weight_block_size,
-            allow_deep_gemm=self.allow_deep_gemm,
-        )
+            )
 
 
 class Fp8KVCacheMethod(BaseKVCacheMethod):
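Reviewer note: after this hunk, `apply()` no longer cares which implementation is installed on `self.fused_experts`; the `functools.partial` from `__init__` and the modular kernel from `set_prepare_finalize` are both invoked through the same keyword-argument call. A hedged illustration of that duck-typed call site (names are placeholders, not the real signatures):

def run_moe(fused_experts_callable, x, w13, w2, topk_weights, topk_ids):
    # Any callable accepting these keyword arguments can be swapped in here,
    # which is what lets the quant method replace the kernel at runtime.
    return fused_experts_callable(hidden_states=x, w1=w13, w2=w2,
                                  topk_weights=topk_weights,
                                  topk_ids=topk_ids)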
@@ -79,7 +79,6 @@ class DbrxExperts(FusedMoE):
             prefix=prefix,
         )
         self.config = config
-        self.tp_size = get_tensor_model_parallel_world_size()
         self.d_model = config.d_model
         self.intermediate_size = (self.config.ffn_config.ffn_hidden_size //
                                   self.tp_size)
@@ -31,9 +31,7 @@ from transformers import PretrainedConfig
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, ModelConfig, VllmConfig
-from vllm.distributed import (get_pp_group,
-                              get_tensor_model_parallel_world_size,
-                              tensor_model_parallel_all_reduce)
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -143,7 +141,8 @@ class DeepseekV2MoE(nn.Module):
                 intermediate_size=intermediate_size,
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
-                reduce_results=False,
+                reduce_results=self.experts.must_reduce_shared_expert_outputs(
+                ),
                 prefix=f"{prefix}.shared_experts",
             )
 
@@ -154,6 +153,7 @@ class DeepseekV2MoE(nn.Module):
             shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
+
         if hidden_states.dtype != torch.float16:
             final_hidden_states = self.experts(
                 hidden_states=hidden_states,
@@ -171,9 +171,11 @@ class DeepseekV2MoE(nn.Module):
             # See DeepseekV2DecoderLayer for more details.
             final_hidden_states = final_hidden_states + shared_output \
                 * (1. / self.routed_scaling_factor)
 
         if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(
-                final_hidden_states)
+            final_hidden_states = (
+                self.experts.maybe_all_reduce_tensor_model_parallel(
+                    final_hidden_states))
 
         return final_hidden_states.view(num_tokens, hidden_dim)
 
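Reviewer note: the model-side change is the same in every MoE block touched below. Instead of unconditionally calling `tensor_model_parallel_all_reduce`, the layer asks the `FusedMoE` module whether a tensor-parallel reduction (and an eager reduction of shared-expert outputs) is still needed, since an expert-parallel prepare/finalize backend may already combine results across ranks. The sketch below is one way to think about that pair of helpers; the single boolean flag is an assumption for illustration, not vLLM's actual internal condition:

def tp_all_reduce_stub(x):
    # Placeholder for tensor_model_parallel_all_reduce.
    return x


class DemoFusedMoE:
    """Toy stand-in for the decision surface the FusedMoE layer exposes."""

    def __init__(self, ep_combine_reduces_output: bool):
        # Hypothetical flag: True when the expert-parallel combine step
        # already produces fully reduced routed-expert outputs.
        self.ep_combine_reduces_output = ep_combine_reduces_output

    def must_reduce_shared_expert_outputs(self) -> bool:
        # If the routed path will not run a final TP all-reduce, the shared
        # expert has to reduce its own output eagerly.
        return self.ep_combine_reduces_output

    def maybe_all_reduce_tensor_model_parallel(self, hidden_states):
        if self.ep_combine_reduces_output:
            return hidden_states  # combine already reduced; skip TP all-reduce
        return tp_all_reduce_stub(hidden_states)


moe = DemoFusedMoE(ep_combine_reduces_output=True)
print(moe.must_reduce_shared_expert_outputs())            # True
print(moe.maybe_all_reduce_tensor_model_parallel([1.0]))  # unchanged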
@@ -25,8 +25,7 @@ from transformers import Llama4TextConfig
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import (get_tensor_model_parallel_world_size,
-                              tensor_model_parallel_all_reduce)
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
@@ -89,7 +88,7 @@ class Llama4MoE(nn.Module):
             quant_config=quant_config,
             bias=False,
             prefix=f"{prefix}.shared_expert",
-            reduce_results=False,  # We need to do scatter before reduce
+            reduce_results=self.experts.must_reduce_shared_expert_outputs(),
         )
 
     def forward(self, hidden_states):
@@ -102,7 +101,8 @@ class Llama4MoE(nn.Module):
         experts_out = routed_out + shared_out
 
         if self.tp_size > 1:
-            experts_out = tensor_model_parallel_all_reduce(experts_out)
+            experts_out = self.experts.maybe_all_reduce_tensor_model_parallel(
+                experts_out)
 
         return experts_out
 
@@ -33,9 +33,7 @@ from transformers import PretrainedConfig
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import (get_pp_group,
-                              get_tensor_model_parallel_world_size,
-                              tensor_model_parallel_all_reduce)
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -129,7 +127,8 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                 intermediate_size=config.shared_expert_intermediate_size,
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
-                reduce_results=False,
+                reduce_results=self.experts.must_reduce_shared_expert_outputs(
+                ),
             )
         else:
             self.shared_expert = None
@@ -156,7 +155,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
         if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(
+            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(  # noqa E501
                 final_hidden_states)
 
         return final_hidden_states.view(orig_shape)
@@ -30,9 +30,7 @@ from transformers import PretrainedConfig
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import (get_pp_group,
-                              get_tensor_model_parallel_world_size,
-                              tensor_model_parallel_all_reduce)
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -137,7 +135,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             router_logits=router_logits)
         final_hidden_states = final_hidden_states
         if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(
+            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(  # noqa E501
                 final_hidden_states)
 
         return final_hidden_states.view(orig_shape)
@@ -158,6 +158,7 @@ class CudaPlatformBase(Platform):
                     "currently not supported with CUDA Graphs.")
                 vllm_config.model_config.enforce_eager = True
                 compilation_config.use_cudagraph = False
+                compilation_config.use_inductor = False
 
     @classmethod
     def get_current_memory_usage(cls,
@@ -865,8 +865,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         assert output is not None, "Output tensor must be provided."
 
         if attn_metadata is None:
-            # Profiling run.
-            return output
+            # The zero fill is required when used with DP + EP
+            # to ensure all ranks within a DP group compute the
+            # same expert outputs.
+            return output.fill_(0)
 
         num_actual_toks = attn_metadata.num_actual_tokens
 
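Reviewer note: zero-filling the preallocated output on the metadata-less (profiling/dummy) path keeps ranks in a data-parallel group numerically consistent; with DP + EP every rank still participates in the expert all-to-all, so it must contribute deterministic values rather than whatever happened to be in the buffer. A minimal torch illustration, assuming the buffer was preallocated with `torch.empty`:

import torch

out = torch.empty(4, 8)       # preallocated output buffer (uninitialized)
dummy = out.fill_(0)          # what the dummy/profiling path now returns
assert dummy is out           # fill_ is in-place and returns the same tensor
assert torch.all(dummy == 0)  # contents are now deterministic across ranks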
@@ -341,7 +341,8 @@ def init_worker_distributed_environment(
                                  distributed_init_method, local_rank)
 
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
-                                      parallel_config.pipeline_parallel_size)
+                                      parallel_config.pipeline_parallel_size,
+                                      parallel_config.enable_expert_parallel)
 
     ensure_kv_transfer_initialized(vllm_config)
 
@@ -265,4 +265,5 @@ def init_tpu_worker_distributed_environment(
         backend="gloo",
     )
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
-                                      parallel_config.pipeline_parallel_size)
+                                      parallel_config.pipeline_parallel_size,
+                                      parallel_config.enable_expert_parallel)
@@ -390,7 +390,8 @@ class CPUWorker(LocalOrDistributedWorkerBase):
 
         ensure_model_parallel_initialized(
             parallel_config.tensor_parallel_size,
-            parallel_config.pipeline_parallel_size)
+            parallel_config.pipeline_parallel_size,
+            parallel_config.enable_expert_parallel)
 
     def get_cache_block_size_bytes(self) -> int:
         """Return the size in bytes of a single KV cache block.
@@ -416,7 +416,8 @@ def init_worker_distributed_environment(
                                  backend='hccl')
 
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
-                                      parallel_config.pipeline_parallel_size)
+                                      parallel_config.pipeline_parallel_size,
+                                      parallel_config.enable_expert_parallel)
 
     if torch.distributed.is_initialized():
         torch_world_size = torch.distributed.get_world_size()
@@ -442,7 +443,8 @@
         torch.distributed.all_reduce(dummy_tensor_hpu)
         assert dummy_tensor_hpu.item() == parallel_config.world_size
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
-                                      parallel_config.pipeline_parallel_size)
+                                      parallel_config.pipeline_parallel_size,
+                                      parallel_config.enable_expert_parallel)
 
 
 def raise_if_cache_size_invalid(num_gpu_blocks, block_size, max_model_len,
@@ -76,7 +76,8 @@ class TPUWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
         )
         ensure_model_parallel_initialized(
             self.parallel_config.tensor_parallel_size,
-            self.parallel_config.pipeline_parallel_size)
+            self.parallel_config.pipeline_parallel_size,
+            self.parallel_config.enable_expert_parallel)
 
         # Device initialization should happen after initializing the distributed
         # runtime.
@@ -530,7 +530,8 @@ def init_worker_distributed_environment(
     init_distributed_environment(parallel_config.world_size, rank,
                                  distributed_init_method, local_rank)
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
-                                      parallel_config.pipeline_parallel_size)
+                                      parallel_config.pipeline_parallel_size,
+                                      parallel_config.enable_expert_parallel)
 
     ensure_kv_transfer_initialized(vllm_config)
 
@@ -176,7 +176,8 @@ class XPUWorker(LoRANotSupportedWorkerBase, Worker):
 
         ensure_model_parallel_initialized(
             parallel_config.tensor_parallel_size,
-            parallel_config.pipeline_parallel_size)
+            parallel_config.pipeline_parallel_size,
+            parallel_config.enable_expert_parallel)
         # global all_reduce needed for overall oneccl warm up
         torch.distributed.all_reduce(torch.zeros(1).xpu())
 
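Reviewer note: every worker backend above now forwards a third argument, `parallel_config.enable_expert_parallel`, into `ensure_model_parallel_initialized`, so model-parallel setup knows whether expert parallelism is enabled. A hedged sketch of how such a call site reads; the config class and the initializer stub here are stand-ins for illustration, not vLLM's `ParallelConfig` or the real initializer:

from dataclasses import dataclass


@dataclass
class DemoParallelConfig:
    tensor_parallel_size: int = 2
    pipeline_parallel_size: int = 1
    enable_expert_parallel: bool = True


def ensure_model_parallel_initialized_stub(tp_size, pp_size,
                                           enable_expert_parallel=False):
    # Placeholder: only the argument shape of the call matters here.
    print(f"tp={tp_size} pp={pp_size} ep={enable_expert_parallel}")


cfg = DemoParallelConfig()
ensure_model_parallel_initialized_stub(cfg.tensor_parallel_size,
                                       cfg.pipeline_parallel_size,
                                       cfg.enable_expert_parallel)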