Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-15 10:57:10 +08:00)
[TPU] remove transpose ops in moe kernel (#18923)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
This commit is contained in:
parent
a521ef06e5
commit
a1cc9f33a3
@@ -18,9 +18,9 @@ setuptools==78.1.0
|
||||
--find-links https://storage.googleapis.com/libtpu-releases/index.html
|
||||
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
|
||||
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
|
||||
torch==2.8.0.dev20250518
|
||||
torchvision==0.22.0.dev20250518
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
|
||||
torch==2.8.0.dev20250529
|
||||
torchvision==0.22.0.dev20250529
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250529-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250529-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250529-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ TOP_KS = [2, 6]
|
||||
# The Pallas GMM kernel requires num_tokens * topk to be a multiple of 16
|
||||
@pytest.mark.parametrize("m", [8, 16, 64, 2048])
|
||||
@pytest.mark.parametrize("n", [128, 1024, 2048])
|
||||
@pytest.mark.parametrize("k", [128, 511, 1024])
|
||||
@pytest.mark.parametrize("k", [128, 512, 1024])
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("ep_size", EP_SIZE)
|
||||
|
||||
@@ -67,15 +67,10 @@ def fused_moe(
|
||||
token_indices = token_indices[topk_argsort_indices]
|
||||
group_sizes = _histogram(topk_indices.to(torch.int32), 0, num_experts - 1)
|
||||
|
||||
# NOTE(woosuk): The GMM Pallas kernel requires a different weight layout
|
||||
# from HF Transformers.
|
||||
w1 = w1.transpose(1, 2)
|
||||
w2 = w2.transpose(1, 2)
|
||||
|
||||
x = hidden_states[token_indices]
|
||||
x = torch.ops.xla.gmm(x, w1, group_sizes)
|
||||
x = torch.ops.xla.gmm(x, w1, group_sizes, transpose_rhs=True)
|
||||
x = F.silu(x[..., :intermediate_size]) * x[..., intermediate_size:]
|
||||
x = torch.ops.xla.gmm(x, w2, group_sizes)
|
||||
x = torch.ops.xla.gmm(x, w2, group_sizes, transpose_rhs=True)
|
||||
x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)
|
||||
|
||||
x = x * topk_weights.unsqueeze(dim=-1)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user