[Bugfix] adding chunking mechanism to fused_moe to handle large inputs (#6029)

This commit is contained in:
Avshalom Manevich 2024-07-02 00:08:29 +03:00 committed by GitHub
parent dec6fc6f3b
commit 12a59959ed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 74 additions and 48 deletions

View File

@@ -29,7 +29,7 @@ def torch_moe(a, w1, w2, score, topk):
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
@pytest.mark.parametrize("m", [512, 222, 33, 1]) @pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [2048, 256, 1024]) @pytest.mark.parametrize("n", [2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 511, 1024]) @pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", [8, 64]) @pytest.mark.parametrize("e", [8, 64])

View File

@@ -32,6 +32,7 @@ if TYPE_CHECKING:
VLLM_OPENVINO_CPU_KV_CACHE_PRECISION: Optional[str] = None VLLM_OPENVINO_CPU_KV_CACHE_PRECISION: Optional[str] = None
VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False
VLLM_XLA_CACHE_PATH: str = "~/.vllm/xla_cache/" VLLM_XLA_CACHE_PATH: str = "~/.vllm/xla_cache/"
VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024
VLLM_USE_RAY_COMPILED_DAG: bool = False VLLM_USE_RAY_COMPILED_DAG: bool = False
VLLM_WORKER_MULTIPROC_METHOD: str = "fork" VLLM_WORKER_MULTIPROC_METHOD: str = "fork"
VLLM_IMAGE_FETCH_TIMEOUT: int = 5 VLLM_IMAGE_FETCH_TIMEOUT: int = 5
@@ -248,6 +249,8 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# Only used for XLA devices such as TPUs. # Only used for XLA devices such as TPUs.
"VLLM_XLA_CACHE_PATH": "VLLM_XLA_CACHE_PATH":
lambda: os.getenv("VLLM_XLA_CACHE_PATH", "~/.vllm/xla_cache/"), lambda: os.getenv("VLLM_XLA_CACHE_PATH", "~/.vllm/xla_cache/"),
"VLLM_FUSED_MOE_CHUNK_SIZE":
lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "65536")),
} }
# end-env-vars-definition # end-env-vars-definition

View File

@@ -8,6 +8,7 @@ import torch
import triton import triton
import triton.language as tl import triton.language as tl
import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
@@ -420,13 +421,12 @@ def fused_experts(hidden_states: torch.Tensor,
torch.float32, torch.float16, torch.bfloat16 torch.float32, torch.float16, torch.bfloat16
] ]
M, _ = hidden_states.shape num_tokens, _ = hidden_states.shape
E, N, _ = w1.shape E, N, _ = w1.shape
# We execute the fused_moe kernel in chunks to circumvent this issue:
if M > 65536:
# https://github.com/vllm-project/vllm/issues/5938 # https://github.com/vllm-project/vllm/issues/5938
raise ValueError("MoE kernel does not support more than 65536 tokens, " CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
f"but got {M}") M = min(num_tokens, CHUNK_SIZE)
if override_config: if override_config:
config = override_config config = override_config
@@ -455,18 +455,43 @@ def fused_experts(hidden_states: torch.Tensor,
device=hidden_states.device, device=hidden_states.device,
dtype=hidden_states.dtype) dtype=hidden_states.dtype)
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
topk_ids, config['BLOCK_SIZE_M'], E)
compute_type = (tl.bfloat16 compute_type = (tl.bfloat16
if hidden_states.dtype == torch.bfloat16 else tl.float16) if hidden_states.dtype == torch.bfloat16 else tl.float16)
invoke_fused_moe_kernel(hidden_states, if inplace:
out_hidden_states = hidden_states
else:
out_hidden_states = torch.empty_like(hidden_states)
for chunk in range((num_tokens // CHUNK_SIZE) + 1):
begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
min((chunk + 1) * CHUNK_SIZE,
num_tokens))
curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
tokens_in_chunk, _ = curr_hidden_states.shape
if tokens_in_chunk == 0:
break
if tokens_in_chunk < CHUNK_SIZE:
# will only happen in the last chunk
intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E))
invoke_fused_moe_kernel(curr_hidden_states,
w1, w1,
intermediate_cache1, intermediate_cache1,
a1_scale, a1_scale,
w1_scale, w1_scale,
topk_weights, curr_topk_weights,
topk_ids, curr_topk_ids,
sorted_token_ids, sorted_token_ids,
expert_ids, expert_ids,
num_tokens_post_padded, num_tokens_post_padded,
@@ -483,8 +508,8 @@ def fused_experts(hidden_states: torch.Tensor,
intermediate_cache3, intermediate_cache3,
a2_scale, a2_scale,
w2_scale, w2_scale,
topk_weights, curr_topk_weights,
topk_ids, curr_topk_ids,
sorted_token_ids, sorted_token_ids,
expert_ids, expert_ids,
num_tokens_post_padded, num_tokens_post_padded,
@@ -494,12 +519,10 @@ def fused_experts(hidden_states: torch.Tensor,
compute_type=compute_type, compute_type=compute_type,
use_fp8=use_fp8) use_fp8=use_fp8)
if inplace: torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1, dim=1,
out=hidden_states) out=out_hidden_states[begin_chunk_idx:end_chunk_idx])
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), return out_hidden_states
dim=1)
def fused_moe( def fused_moe(