diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
index 21982b01b9cc7..07b898787eba7 100755
--- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
+++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
@@ -47,7 +47,9 @@ docker run --privileged --net host --shm-size=16G -it \
     && echo TEST_10 \
     && pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py \
     && echo TEST_11 \
-    && pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py" \
+    && pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py \
+    && echo TEST_12 \
+    && pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py" \


 # TODO: This test fails because it uses RANDOM_SEED sampling
diff --git a/tests/tpu/test_moe_pallas.py b/tests/tpu/test_moe_pallas.py
new file mode 100644
index 0000000000000..13fc8bc8fa2ed
--- /dev/null
+++ b/tests/tpu/test_moe_pallas.py
@@ -0,0 +1,87 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the Pallas MoE implementation.
+
+Run `pytest tests/tpu/test_moe_pallas.py`.
+"""
+import pytest
+import torch
+
+# yapf conflicts with isort for this block
+# yapf: disable
+from vllm.model_executor.layers.fused_moe.moe_pallas import (
+    fused_moe as pallas_moe)
+from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
+    fused_moe as torch_moe)
+# yapf: enable
+from vllm.platforms import current_platform
+
+if not current_platform.is_tpu():
+    pytest.skip("This test needs a TPU.", allow_module_level=True)
+
+NUM_EXPERTS = [8, 64]
+EP_SIZE = [1]
+TOP_KS = [2, 6]
+
+
+# The Pallas GMM kernel requires num_tokens * topk to be a multiple of 16
+@pytest.mark.parametrize("m", [8, 16, 64, 2048])
+@pytest.mark.parametrize("n", [128, 1024, 2048])
+@pytest.mark.parametrize("k", [128, 511, 1024])
+@pytest.mark.parametrize("e", NUM_EXPERTS)
+@pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("ep_size", EP_SIZE)
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+def test_pallas_moe(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    ep_size: int,
+    dtype: torch.dtype,
+):
+    import torch_xla.core.xla_model as xm
+    with torch.device(xm.xla_device()):
+        a = torch.randn((m, k), dtype=dtype) / 10
+        w1 = torch.randn((e, 2 * n, k), dtype=dtype) / 10
+        w2 = torch.randn((e, k, n), dtype=dtype) / 10
+
+        score = torch.randn((m, e), dtype=dtype)
+
+        # TODO: Support expert parallelism (ep_size > 1)
+        if ep_size > 1:
+            pytest.skip("No support for ep_size > 1 yet")
+        else:
+            e_map = None
+
+        # Run both implementations
+        torch_output = torch_moe(
+            hidden_states=a,
+            w1=w1,
+            w2=w2,
+            gating_output=score,
+            topk=topk,
+            global_num_experts=e,
+            expert_map=e_map,
+            renormalize=False,
+        )
+
+        pallas_output = pallas_moe(
+            hidden_states=a,
+            w1=w1,
+            w2=w2,
+            gating_output=score,
+            topk=topk,
+            global_num_experts=e,
+            expert_map=e_map,
+            renormalize=False,
+        )
+        xm.mark_step()
+
+        # Compare outputs
+        torch.testing.assert_close(
+            pallas_output.cpu(),
+            torch_output.cpu(),
+            atol=2e-2,
+            rtol=0,
+        )
diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py
index 91d20a4e7bfc0..19642a939b481 100644
--- a/vllm/attention/backends/pallas.py
+++ b/vllm/attention/backends/pallas.py
@@ -123,7 +123,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         self.logits_soft_cap = logits_soft_cap
         if head_size % 128 != 0:
-            raise NotImplementedError("Head size must be a multiple of 128.")
+            raise NotImplementedError(
+                f"Head size must be a multiple of 128, found {head_size}.")
         if alibi_slopes is not None:
             raise NotImplementedError("Alibi slopes is not supported.")
         if sliding_window is not None:
diff --git a/vllm/model_executor/layers/fused_moe/moe_pallas.py b/vllm/model_executor/layers/fused_moe/moe_pallas.py
index 0365afa10a459..8f28b64ed487c 100644
--- a/vllm/model_executor/layers/fused_moe/moe_pallas.py
+++ b/vllm/model_executor/layers/fused_moe/moe_pallas.py
@@ -11,7 +11,9 @@ def fused_moe(
     w2: torch.Tensor,
     gating_output: torch.Tensor,
     topk: int,
-    renormalize: bool,
+    global_num_experts: int,
+    expert_map: torch.Tensor = None,
+    renormalize: bool = False,
 ) -> torch.Tensor:
     """
     Args:
@@ -20,6 +22,7 @@ def fused_moe(
         w2: [num_experts, hidden_size, intermediate_size]
         gating_output: [*, num_experts]
     """
+    assert expert_map is None, "expert_map is not supported for pallas MoE."
     orig_shape = hidden_states.shape
     hidden_size = hidden_states.shape[-1]
     num_tokens = hidden_states.shape[:-1].numel()
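A quick sanity check on the constraint noted in the new test: the Pallas GMM kernel requires num_tokens * topk to be a multiple of 16, so every (m, topk) pair in the parametrize grids above must satisfy it. The snippet below is only an illustrative sketch, not part of the patch; M_VALUES and TOP_KS are local names that mirror the values used in test_pallas_moe.

# Sketch only (not part of the patch): replay the arithmetic behind the
# "num_tokens * topk must be a multiple of 16" requirement for the exact
# values parametrized in test_pallas_moe. Each token is routed to topk
# experts, so m * topk is the total number of routed (token, expert) pairs.
M_VALUES = [8, 16, 64, 2048]  # mirrors @pytest.mark.parametrize("m", ...)
TOP_KS = [2, 6]               # mirrors @pytest.mark.parametrize("topk", ...)

for m in M_VALUES:
    for topk in TOP_KS:
        assert (m * topk) % 16 == 0, f"m={m}, topk={topk} violates the constraint"
print("All (m, topk) combinations satisfy num_tokens * topk % 16 == 0.")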