[TPU] Add kernel test for moe_pallas (#17496)
Signed-off-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
parent a17cef70ea
commit e50a1f1a9c
@@ -47,7 +47,9 @@ docker run --privileged --net host --shm-size=16G -it \
     && echo TEST_10 \
     && pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py \
     && echo TEST_11 \
-    && pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py" \
+    && pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py \
+    && echo TEST_12 \
+    && pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py" \
 
 
 # TODO: This test fails because it uses RANDOM_SEED sampling
tests/tpu/test_moe_pallas.py (new file, 87 lines)
@@ -0,0 +1,87 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for the Pallas MOE implementation.

Run `pytest tests/kernels/moe/test_moe_pallas.py`.
"""
import pytest
import torch

# yapf conflicts with isort for this block
# yapf: disable
from vllm.model_executor.layers.fused_moe.moe_pallas import (
    fused_moe as pallas_moe)
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
    fused_moe as torch_moe)
# yapf: enable
from vllm.platforms import current_platform

if not current_platform.is_tpu():
    pytest.skip("This test needs a TPU.", allow_module_level=True)

NUM_EXPERTS = [8, 64]
EP_SIZE = [1]
TOP_KS = [2, 6]


# The Pallas GMM kernel requires num_tokens * topk to be a multiple of 16
@pytest.mark.parametrize("m", [8, 16, 64, 2048])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_pallas_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    ep_size: int,
    dtype: torch.dtype,
):
    import torch_xla.core.xla_model as xm
    with torch.device(xm.xla_device()):
        a = torch.randn((m, k), dtype=dtype) / 10
        w1 = torch.randn((e, 2 * n, k), dtype=dtype) / 10
        w2 = torch.randn((e, k, n), dtype=dtype) / 10

        score = torch.randn((m, e), dtype=dtype)

        # TODO: Support ep
        if ep_size > 1:
            pytest.skip("No support for ep_size > 1 yet")
        else:
            e_map = None

        # Run both implementations
        torch_output = torch_moe(
            hidden_states=a,
            w1=w1,
            w2=w2,
            gating_output=score,
            topk=topk,
            global_num_experts=e,
            expert_map=e_map,
            renormalize=False,
        )

        pallas_output = pallas_moe(
            hidden_states=a,
            w1=w1,
            w2=w2,
            gating_output=score,
            topk=topk,
            global_num_experts=e,
            expert_map=e_map,
            renormalize=False,
        )
        xm.mark_step()

        # Compare outputs
        torch.testing.assert_close(
            pallas_output.cpu(),
            torch_output.cpu(),
            atol=2e-2,
            rtol=0,
        )
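The parametrize values above follow the constraint noted in the new file's comment: the Pallas GMM kernel requires num_tokens * topk to be a multiple of 16. A minimal standalone check (an illustrative sketch, not part of the commit) that every parametrized (m, topk) pair satisfies this:

# Illustrative sketch, not part of the commit: every (m, topk) combination
# used by the test above must keep num_tokens * topk a multiple of 16,
# as the Pallas GMM kernel requires.
M_VALUES = [8, 16, 64, 2048]
TOP_KS = [2, 6]

for m in M_VALUES:
    for topk in TOP_KS:
        assert (m * topk) % 16 == 0, f"(m={m}, topk={topk}) breaks the constraint"
print("all (m, topk) pairs keep num_tokens * topk a multiple of 16")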
@@ -123,7 +123,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         self.logits_soft_cap = logits_soft_cap
         if head_size % 128 != 0:
-            raise NotImplementedError("Head size must be a multiple of 128.")
+            raise NotImplementedError(
+                f"Head size must be a multiple of 128, found {head_size}.")
         if alibi_slopes is not None:
             raise NotImplementedError("Alibi slopes is not supported.")
         if sliding_window is not None:
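The reworded error above now reports the offending value. A small illustration (not part of the commit) of what the new message looks like at runtime:

# Illustrative only: reproduce the improved error text for an invalid head size.
head_size = 96
if head_size % 128 != 0:
    raise NotImplementedError(
        f"Head size must be a multiple of 128, found {head_size}.")
# NotImplementedError: Head size must be a multiple of 128, found 96.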
@@ -11,7 +11,9 @@ def fused_moe(
     w2: torch.Tensor,
     gating_output: torch.Tensor,
     topk: int,
-    renormalize: bool,
+    global_num_experts: int,
+    expert_map: torch.Tensor = None,
+    renormalize: bool = False,
 ) -> torch.Tensor:
     """
     Args:
@@ -20,6 +22,7 @@ def fused_moe(
         w2: [num_experts, hidden_size, intermediate_size]
         gating_output: [*, num_experts]
     """
+    assert expert_map is None, "expert_map is not supported for pallas MoE."
     orig_shape = hidden_states.shape
     hidden_size = hidden_states.shape[-1]
     num_tokens = hidden_states.shape[:-1].numel()
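For context on the signature change above, here is a hedged usage sketch (not part of the commit). It assumes a TPU host with torch_xla installed, since the Pallas kernel only runs there; shapes follow the docstring, and the tensor values are illustrative:

# Hedged sketch: calling the updated moe_pallas.fused_moe signature with the
# new global_num_experts and expert_map parameters. Requires a TPU via torch_xla.
import torch
import torch_xla.core.xla_model as xm

from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe

num_tokens, hidden_size, intermediate_size = 16, 128, 256
num_experts, topk = 8, 2  # num_tokens * topk = 32, a multiple of 16

with torch.device(xm.xla_device()):
    hidden_states = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16)
    w1 = torch.randn(num_experts, 2 * intermediate_size, hidden_size,
                     dtype=torch.bfloat16)
    w2 = torch.randn(num_experts, hidden_size, intermediate_size,
                     dtype=torch.bfloat16)
    gating_output = torch.randn(num_tokens, num_experts, dtype=torch.bfloat16)

    out = fused_moe(
        hidden_states=hidden_states,
        w1=w1,
        w2=w2,
        gating_output=gating_output,
        topk=topk,
        global_num_experts=num_experts,
        expert_map=None,  # anything else trips the new assert
        renormalize=False,
    )
    xm.mark_step()
    assert out.shape == (num_tokens, hidden_size)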