mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 03:45:01 +08:00
[TPU] Re-enable the Pallas MoE kernel (#18025)
Signed-off-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
parent
23baa2180b
commit
3b17ea26e4
@ -18,9 +18,9 @@ setuptools==78.1.0
|
|||||||
--find-links https://storage.googleapis.com/libtpu-releases/index.html
|
--find-links https://storage.googleapis.com/libtpu-releases/index.html
|
||||||
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
|
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
|
||||||
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
|
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
|
||||||
torch==2.8.0.dev20250430
|
torch==2.8.0.dev20250518
|
||||||
torchvision==0.22.0.dev20250430
|
torchvision==0.22.0.dev20250518
|
||||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
|
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
|
||||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
|
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
|
||||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
|
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
|
||||||
|
|
||||||
|
|||||||
@ -50,8 +50,7 @@ if is_rocm_aiter_moe_enabled():
|
|||||||
else:
|
else:
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
|
from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
|
||||||
if current_platform.is_tpu():
|
if current_platform.is_tpu():
|
||||||
# the iterative moe implementation is used until the moe_pallas is fixed
|
from .moe_pallas import fused_moe as fused_moe_pallas
|
||||||
from .moe_torch_iterative import fused_moe as fused_moe_pallas
|
|
||||||
else:
|
else:
|
||||||
fused_moe_pallas = None # type: ignore
|
fused_moe_pallas = None # type: ignore
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|||||||
@ -2,7 +2,23 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch_xla.experimental.custom_kernel import _histogram
|
|
||||||
|
|
||||||
|
def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Compute the histogram of a int32 tensor. The bin edges are defined by the
|
||||||
|
min and max values, with step = 1.
|
||||||
|
"""
|
||||||
|
assert input.dtype == torch.int32, "input must be of torch.int32 dtype."
|
||||||
|
assert min <= max, "min must be less than or equal to max."
|
||||||
|
|
||||||
|
def searchsorted(sorted_sequence: torch.Tensor,
|
||||||
|
values_to_search: torch.Tensor) -> torch.Tensor:
|
||||||
|
return (sorted_sequence.unsqueeze(1) == values_to_search).sum(dim=1)
|
||||||
|
|
||||||
|
bin_edges = torch.linspace(min, max, max - min + 1,
|
||||||
|
dtype=input.dtype).to(input.device)
|
||||||
|
return searchsorted(bin_edges, input).to(torch.int32)
|
||||||
|
|
||||||
|
|
||||||
def fused_moe(
|
def fused_moe(
|
||||||
@ -61,7 +77,7 @@ def fused_moe(
|
|||||||
x = torch.ops.xla.gmm(x, w2, group_sizes)
|
x = torch.ops.xla.gmm(x, w2, group_sizes)
|
||||||
x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)
|
x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)
|
||||||
|
|
||||||
x = x * topk_weights.unsqueeze_(dim=-1)
|
x = x * topk_weights.unsqueeze(dim=-1)
|
||||||
x = x.sum(dim=-2)
|
x = x.sum(dim=-2)
|
||||||
x = x.reshape(orig_shape)
|
x = x.reshape(orig_shape)
|
||||||
return x
|
return x
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user