mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 00:45:01 +08:00
[TPU] Re-enable the Pallas MoE kernel (#18025)
Signed-off-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
parent
23baa2180b
commit
3b17ea26e4
@ -18,9 +18,9 @@ setuptools==78.1.0
|
||||
--find-links https://storage.googleapis.com/libtpu-releases/index.html
|
||||
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
|
||||
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
|
||||
torch==2.8.0.dev20250430
|
||||
torchvision==0.22.0.dev20250430
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
|
||||
torch==2.8.0.dev20250518
|
||||
torchvision==0.22.0.dev20250518
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
|
||||
|
||||
|
||||
@ -50,8 +50,7 @@ if is_rocm_aiter_moe_enabled():
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
|
||||
if current_platform.is_tpu():
|
||||
# the iterative moe implementation is used until the moe_pallas is fixed
|
||||
from .moe_torch_iterative import fused_moe as fused_moe_pallas
|
||||
from .moe_pallas import fused_moe as fused_moe_pallas
|
||||
else:
|
||||
fused_moe_pallas = None # type: ignore
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -2,7 +2,23 @@
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch_xla.experimental.custom_kernel import _histogram
|
||||
|
||||
|
||||
def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor:
|
||||
"""
|
||||
Compute the histogram of a int32 tensor. The bin edges are defined by the
|
||||
min and max values, with step = 1.
|
||||
"""
|
||||
assert input.dtype == torch.int32, "input must be of torch.int32 dtype."
|
||||
assert min <= max, "min must be less than or equal to max."
|
||||
|
||||
def searchsorted(sorted_sequence: torch.Tensor,
|
||||
values_to_search: torch.Tensor) -> torch.Tensor:
|
||||
return (sorted_sequence.unsqueeze(1) == values_to_search).sum(dim=1)
|
||||
|
||||
bin_edges = torch.linspace(min, max, max - min + 1,
|
||||
dtype=input.dtype).to(input.device)
|
||||
return searchsorted(bin_edges, input).to(torch.int32)
|
||||
|
||||
|
||||
def fused_moe(
|
||||
@ -61,7 +77,7 @@ def fused_moe(
|
||||
x = torch.ops.xla.gmm(x, w2, group_sizes)
|
||||
x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)
|
||||
|
||||
x = x * topk_weights.unsqueeze_(dim=-1)
|
||||
x = x * topk_weights.unsqueeze(dim=-1)
|
||||
x = x.sum(dim=-2)
|
||||
x = x.reshape(orig_shape)
|
||||
return x
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user