[TPU] Add kernel test for moe_pallas (#17496)
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Parent: a17cef70ea
Commit: e50a1f1a9c
@@ -47,7 +47,9 @@ docker run --privileged --net host --shm-size=16G -it \
   && echo TEST_10 \
   && pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py \
   && echo TEST_11 \
-  && pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py" \
+  && pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py \
+  && echo TEST_12 \
+  && pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py" \
 
 
 # TODO: This test fails because it uses RANDOM_SEED sampling
tests/tpu/test_moe_pallas.py (new file, 87 lines)
@@ -0,0 +1,87 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for the Pallas MOE implementation.

Run `pytest tests/tpu/test_moe_pallas.py`.
"""
import pytest
import torch

# yapf conflicts with isort for this block
# yapf: disable
from vllm.model_executor.layers.fused_moe.moe_pallas import (
    fused_moe as pallas_moe)
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
    fused_moe as torch_moe)
# yapf: enable
from vllm.platforms import current_platform

if not current_platform.is_tpu():
    pytest.skip("This test needs a TPU.", allow_module_level=True)

NUM_EXPERTS = [8, 64]
EP_SIZE = [1]
TOP_KS = [2, 6]


# The Pallas GMM kernel requires num_tokens * topk to be a multiple of 16
@pytest.mark.parametrize("m", [8, 16, 64, 2048])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_pallas_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    ep_size: int,
    dtype: torch.dtype,
):
    import torch_xla.core.xla_model as xm
    with torch.device(xm.xla_device()):
        a = torch.randn((m, k), dtype=dtype) / 10
        w1 = torch.randn((e, 2 * n, k), dtype=dtype) / 10
        w2 = torch.randn((e, k, n), dtype=dtype) / 10

        score = torch.randn((m, e), dtype=dtype)

        # TODO: Support ep
        if ep_size > 1:
            pytest.skip("No support for ep_size > 1 yet")
        else:
            e_map = None

        # Run both implementations
        torch_output = torch_moe(
            hidden_states=a,
            w1=w1,
            w2=w2,
            gating_output=score,
            topk=topk,
            global_num_experts=e,
            expert_map=e_map,
            renormalize=False,
        )

        pallas_output = pallas_moe(
            hidden_states=a,
            w1=w1,
            w2=w2,
            gating_output=score,
            topk=topk,
            global_num_experts=e,
            expert_map=e_map,
            renormalize=False,
        )
        xm.mark_step()

        # Compare outputs
        torch.testing.assert_close(
            pallas_output.cpu(),
            torch_output.cpu(),
            atol=2e-2,
            rtol=0,
        )
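
For context on what the two implementations above are expected to compute, here is a minimal, self-contained reference MoE forward pass written against the same weight shapes the test uses (w1: [e, 2*n, k], w2: [e, k, n]). It is an illustrative sketch only, not vLLM's moe_torch_iterative or Pallas code; it assumes softmax top-k routing without renormalization (matching renormalize=False) and the common gated-SiLU expert, which the 2*n gate/up split in w1 suggests. Note also that every (m, topk) pair in the parametrization keeps m * topk a multiple of 16, satisfying the Pallas GMM constraint called out in the comment above.

# Hypothetical reference implementation (illustrative only, not vLLM code).
import torch
import torch.nn.functional as F


def reference_moe(
    hidden_states: torch.Tensor,  # [m, k]
    w1: torch.Tensor,             # [e, 2*n, k] (gate and up projections stacked)
    w2: torch.Tensor,             # [e, k, n]   (down projection)
    gating_output: torch.Tensor,  # [m, e]
    topk: int,
) -> torch.Tensor:
    # Softmax routing, then pick the top-k experts per token (no renormalization).
    probs = gating_output.softmax(dim=-1)
    topk_weights, topk_ids = probs.topk(topk, dim=-1)  # both [m, topk]
    out = torch.zeros_like(hidden_states)
    n = w2.shape[-1]
    for expert in range(w1.shape[0]):
        token_ids, slot = (topk_ids == expert).nonzero(as_tuple=True)
        if token_ids.numel() == 0:
            continue
        x = hidden_states[token_ids]                   # [t, k]
        gate_up = x @ w1[expert].t()                   # [t, 2*n]
        act = F.silu(gate_up[:, :n]) * gate_up[:, n:]  # [t, n]
        expert_out = act @ w2[expert].t()              # [t, k]
        # Weight each expert's contribution by its routing probability.
        out.index_add_(
            0, token_ids,
            expert_out * topk_weights[token_ids, slot].unsqueeze(-1))
    return out

This sketch only documents the intended math; the test itself uses moe_torch_iterative as the reference and runs both implementations on the TPU device.
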
@@ -123,7 +123,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         self.logits_soft_cap = logits_soft_cap
         if head_size % 128 != 0:
-            raise NotImplementedError("Head size must be a multiple of 128.")
+            raise NotImplementedError(
+                f"Head size must be a multiple of 128, found {head_size}.")
         if alibi_slopes is not None:
             raise NotImplementedError("Alibi slopes is not supported.")
         if sliding_window is not None:
vllm/model_executor/layers/fused_moe/moe_pallas.py
@@ -11,7 +11,9 @@ def fused_moe(
     w2: torch.Tensor,
     gating_output: torch.Tensor,
     topk: int,
-    renormalize: bool,
+    global_num_experts: int,
+    expert_map: torch.Tensor = None,
+    renormalize: bool = False,
 ) -> torch.Tensor:
     """
     Args:
@@ -20,6 +22,7 @@ def fused_moe(
         w2: [num_experts, hidden_size, intermediate_size]
         gating_output: [*, num_experts]
     """
+    assert expert_map is None, "expert_map is not supported for pallas MoE."
     orig_shape = hidden_states.shape
     hidden_size = hidden_states.shape[-1]
     num_tokens = hidden_states.shape[:-1].numel()
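
A hedged usage sketch of the updated fused_moe signature, mirroring the keyword arguments the new test passes. Shapes follow the test and docstring above (w1: [num_experts, 2 * intermediate_size, hidden_size], w2: [num_experts, hidden_size, intermediate_size]); the size values below are arbitrary examples, and in practice the tensors must live on a TPU device via torch_xla, as the test does.

# Illustrative call of the updated signature (not an excerpt from vLLM).
import torch
import torch_xla.core.xla_model as xm

from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe

num_tokens, hidden_size, intermediate_size = 16, 128, 256
num_experts, topk = 8, 2  # num_tokens * topk is a multiple of 16
dtype = torch.bfloat16

with torch.device(xm.xla_device()):
    hidden_states = torch.randn((num_tokens, hidden_size), dtype=dtype)
    w1 = torch.randn((num_experts, 2 * intermediate_size, hidden_size), dtype=dtype)
    w2 = torch.randn((num_experts, hidden_size, intermediate_size), dtype=dtype)
    gating_output = torch.randn((num_tokens, num_experts), dtype=dtype)

    out = fused_moe(
        hidden_states=hidden_states,
        w1=w1,
        w2=w2,
        gating_output=gating_output,
        topk=topk,
        global_num_experts=num_experts,
        expert_map=None,        # expert parallelism is not supported yet
        renormalize=False,
    )
    xm.mark_step()
    # out has the same shape as hidden_states: [num_tokens, hidden_size]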