diff --git a/requirements/tpu.txt b/requirements/tpu.txt
index 3b204a8f9905..edc8b2a45667 100644
--- a/requirements/tpu.txt
+++ b/requirements/tpu.txt
@@ -18,9 +18,9 @@ setuptools==78.1.0
 --find-links https://storage.googleapis.com/libtpu-releases/index.html
 --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
 --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-torch==2.8.0.dev20250518
-torchvision==0.22.0.dev20250518
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch==2.8.0.dev20250529
+torchvision==0.22.0.dev20250529
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250529-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250529-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250529-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
diff --git a/tests/tpu/test_moe_pallas.py b/tests/tpu/test_moe_pallas.py
index 13fc8bc8fa2e..19df22f78039 100644
--- a/tests/tpu/test_moe_pallas.py
+++ b/tests/tpu/test_moe_pallas.py
@@ -26,7 +26,7 @@ TOP_KS = [2, 6]
 # The Pallas GMM kernel requires num_tokens * topk to be a multiple of 16
 @pytest.mark.parametrize("m", [8, 16, 64, 2048])
 @pytest.mark.parametrize("n", [128, 1024, 2048])
-@pytest.mark.parametrize("k", [128, 511, 1024])
+@pytest.mark.parametrize("k", [128, 512, 1024])
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("ep_size", EP_SIZE)
diff --git a/vllm/model_executor/layers/fused_moe/moe_pallas.py b/vllm/model_executor/layers/fused_moe/moe_pallas.py
index 53945999288d..9d8bd62c6969 100644
--- a/vllm/model_executor/layers/fused_moe/moe_pallas.py
+++ b/vllm/model_executor/layers/fused_moe/moe_pallas.py
@@ -67,15 +67,10 @@ def fused_moe(
     token_indices = token_indices[topk_argsort_indices]
     group_sizes = _histogram(topk_indices.to(torch.int32), 0, num_experts - 1)
 
-    # NOTE(woosuk): The GMM Pallas kernel requires a different weight layout
-    # from HF Transformers.
-    w1 = w1.transpose(1, 2)
-    w2 = w2.transpose(1, 2)
-
     x = hidden_states[token_indices]
-    x = torch.ops.xla.gmm(x, w1, group_sizes)
+    x = torch.ops.xla.gmm(x, w1, group_sizes, transpose_rhs=True)
     x = F.silu(x[..., :intermediate_size]) * x[..., intermediate_size:]
-    x = torch.ops.xla.gmm(x, w2, group_sizes)
+    x = torch.ops.xla.gmm(x, w2, group_sizes, transpose_rhs=True)
 
     x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)
     x = x * topk_weights.unsqueeze(dim=-1)
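
The behavioral change above is confined to moe_pallas.py: instead of materializing transposed copies of the HF-layout expert weights (w1, w2) before each grouped matmul, the Pallas GMM kernel is now asked to transpose its right-hand side internally via transpose_rhs=True. The sketch below is not part of the patch; it is a minimal illustration of the equivalence of the two call patterns, assuming a TPU VM with the torch_xla nightly pinned in requirements/tpu.txt, illustrative shapes, and that importing torch_xla.experimental.custom_kernel is what registers torch.ops.xla.gmm.

# Sketch only: compares the old call pattern (explicit transpose) with the new
# one (transpose_rhs=True). Assumes a TPU VM with torch_xla installed; shapes
# and names below are illustrative, not taken from vLLM.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.custom_kernel  # noqa: F401  (assumed to register torch.ops.xla.gmm)

device = xm.xla_device()
num_experts, tokens_per_expert, hidden, intermediate = 4, 16, 128, 256

# Tokens already sorted by expert, as in fused_moe().
x = torch.randn(num_experts * tokens_per_expert, hidden, device=device)
# HF-style expert weights: (num_experts, out_features, in_features).
w = torch.randn(num_experts, intermediate, hidden, device=device)
group_sizes = torch.full((num_experts,), tokens_per_expert,
                         dtype=torch.int32, device=device)

# Old path: transpose to (num_experts, in_features, out_features) up front.
out_old = torch.ops.xla.gmm(x, w.transpose(1, 2), group_sizes)
# New path: keep the HF layout and let the kernel transpose the RHS.
out_new = torch.ops.xla.gmm(x, w, group_sizes, transpose_rhs=True)

# Results should match up to kernel-level floating-point differences.
print(torch.max(torch.abs(out_old - out_new)).cpu())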