[TPU] Add kernel test for moe_pallas (#17496)

Signed-off-by: Michael Goin <mgoin64@gmail.com>
Michael Goin 2025-05-06 20:59:57 -04:00 committed by GitHub
parent a17cef70ea
commit e50a1f1a9c
4 changed files with 96 additions and 3 deletions


@@ -47,7 +47,9 @@ docker run --privileged --net host --shm-size=16G -it \
 && echo TEST_10 \
 && pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py \
 && echo TEST_11 \
-&& pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py" \
+&& pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py \
+&& echo TEST_12 \
+&& pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py" \
 # TODO: This test fails because it uses RANDOM_SEED sampling
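Outside the Docker CI wrapper above, the same test can be launched from a small Python entry point; a minimal sketch, assuming vllm and torch_xla are installed on a TPU host and the command is run from the repository root:

# Minimal sketch: run the new TPU MoE kernel test without the Docker CI wrapper.
# Assumes a TPU host with vllm and torch_xla installed, executed from the repo root.
import sys

import pytest

if __name__ == "__main__":
    sys.exit(pytest.main(["-s", "-v", "tests/tpu/test_moe_pallas.py"]))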


@@ -0,0 +1,87 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for the Pallas MOE implementation.

Run `pytest tests/tpu/test_moe_pallas.py`.
"""
import pytest
import torch

# yapf conflicts with isort for this block
# yapf: disable
from vllm.model_executor.layers.fused_moe.moe_pallas import (
    fused_moe as pallas_moe)
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
    fused_moe as torch_moe)
# yapf: enable
from vllm.platforms import current_platform

if not current_platform.is_tpu():
    pytest.skip("This test needs a TPU.", allow_module_level=True)

NUM_EXPERTS = [8, 64]
EP_SIZE = [1]
TOP_KS = [2, 6]


# The Pallas GMM kernel requires num_tokens * topk to be a multiple of 16
@pytest.mark.parametrize("m", [8, 16, 64, 2048])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_pallas_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    ep_size: int,
    dtype: torch.dtype,
):
    import torch_xla.core.xla_model as xm
    with torch.device(xm.xla_device()):
        a = torch.randn((m, k), dtype=dtype) / 10
        w1 = torch.randn((e, 2 * n, k), dtype=dtype) / 10
        w2 = torch.randn((e, k, n), dtype=dtype) / 10
        score = torch.randn((m, e), dtype=dtype)

        # TODO: Support ep
        if ep_size > 1:
            pytest.skip("No support for ep_size > 1 yet")
        else:
            e_map = None

        # Run both implementations
        torch_output = torch_moe(
            hidden_states=a,
            w1=w1,
            w2=w2,
            gating_output=score,
            topk=topk,
            global_num_experts=e,
            expert_map=e_map,
            renormalize=False,
        )
        pallas_output = pallas_moe(
            hidden_states=a,
            w1=w1,
            w2=w2,
            gating_output=score,
            topk=topk,
            global_num_experts=e,
            expert_map=e_map,
            renormalize=False,
        )
        xm.mark_step()

        # Compare outputs
        torch.testing.assert_close(
            pallas_output.cpu(),
            torch_output.cpu(),
            atol=2e-2,
            rtol=0,
        )

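The (m, topk) grid above is picked so that the GMM constraint in the test's comment always holds; a quick stand-alone check, using the values copied from the parametrize lists:

# Sanity check that every (m, topk) pair from the test's parametrize lists
# keeps num_tokens * topk a multiple of 16, as the Pallas GMM kernel requires.
ms = [8, 16, 64, 2048]
top_ks = [2, 6]

for m in ms:
    for topk in top_ks:
        assert (m * topk) % 16 == 0, f"m={m}, topk={topk} breaks the constraint"
print("all (m, topk) combinations satisfy the multiple-of-16 constraint")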

@@ -123,7 +123,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         self.logits_soft_cap = logits_soft_cap
         if head_size % 128 != 0:
-            raise NotImplementedError("Head size must be a multiple of 128.")
+            raise NotImplementedError(
+                f"Head size must be a multiple of 128, found {head_size}.")
         if alibi_slopes is not None:
             raise NotImplementedError("Alibi slopes is not supported.")
         if sliding_window is not None:


@@ -11,7 +11,9 @@ def fused_moe(
     w2: torch.Tensor,
     gating_output: torch.Tensor,
     topk: int,
-    renormalize: bool,
+    global_num_experts: int,
+    expert_map: torch.Tensor = None,
+    renormalize: bool = False,
 ) -> torch.Tensor:
     """
     Args:
@@ -20,6 +22,7 @@
         w2: [num_experts, hidden_size, intermediate_size]
         gating_output: [*, num_experts]
     """
+    assert expert_map is None, "expert_map is not supported for pallas MoE."
     orig_shape = hidden_states.shape
     hidden_size = hidden_states.shape[-1]
     num_tokens = hidden_states.shape[:-1].numel()
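For reference, a minimal call sketch against the updated fused_moe signature, mirroring the keyword arguments the new test exercises; the shapes, dtype, and TPU/torch_xla setup are illustrative assumptions, not part of the change:

# Minimal usage sketch (assumes a TPU host with torch_xla installed);
# mirrors the keyword arguments exercised by the new kernel test.
import torch
import torch_xla.core.xla_model as xm

from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe

with torch.device(xm.xla_device()):
    # 16 tokens, hidden_size 128, 8 experts, intermediate_size 128, topk 2:
    # num_tokens * topk = 32, a multiple of 16 as the Pallas GMM kernel requires.
    hidden_states = torch.randn((16, 128), dtype=torch.bfloat16) / 10
    w1 = torch.randn((8, 2 * 128, 128), dtype=torch.bfloat16) / 10
    w2 = torch.randn((8, 128, 128), dtype=torch.bfloat16) / 10
    gating_output = torch.randn((16, 8), dtype=torch.bfloat16)

    out = fused_moe(
        hidden_states=hidden_states,
        w1=w1,
        w2=w2,
        gating_output=gating_output,
        topk=2,
        global_num_experts=8,
        expert_map=None,  # expert parallelism is not supported yet
        renormalize=False,
    )
    xm.mark_step()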