diff --git a/requirements/tpu.txt b/requirements/tpu.txt
index 11501bc5d92f..3b204a8f9905 100644
--- a/requirements/tpu.txt
+++ b/requirements/tpu.txt
@@ -18,9 +18,9 @@ setuptools==78.1.0
 --find-links https://storage.googleapis.com/libtpu-releases/index.html
 --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
 --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-torch==2.8.0.dev20250430
-torchvision==0.22.0.dev20250430
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch==2.8.0.dev20250518
+torchvision==0.22.0.dev20250518
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250518-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
 
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index f1cb77f64eae..31efe16d1c27 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -50,8 +50,7 @@ if is_rocm_aiter_moe_enabled():
 else:
     from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
 if current_platform.is_tpu():
-    # the iterative moe implementation is used until the moe_pallas is fixed
-    from .moe_torch_iterative import fused_moe as fused_moe_pallas
+    from .moe_pallas import fused_moe as fused_moe_pallas
 else:
     fused_moe_pallas = None  # type: ignore
 logger = init_logger(__name__)
diff --git a/vllm/model_executor/layers/fused_moe/moe_pallas.py b/vllm/model_executor/layers/fused_moe/moe_pallas.py
index 8f28b64ed487..babeb97308a9 100644
--- a/vllm/model_executor/layers/fused_moe/moe_pallas.py
+++ b/vllm/model_executor/layers/fused_moe/moe_pallas.py
@@ -2,7 +2,23 @@
 
 import torch
 import torch.nn.functional as F
-from torch_xla.experimental.custom_kernel import _histogram
+
+
+def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor:
+    """
+    Compute the histogram of an int32 tensor. The bin edges are defined by
+    the min and max values, with step = 1.
+    """
+    assert input.dtype == torch.int32, "input must be of torch.int32 dtype."
+    assert min <= max, "min must be less than or equal to max."
+
+    def searchsorted(sorted_sequence: torch.Tensor,
+                     values_to_search: torch.Tensor) -> torch.Tensor:
+        return (sorted_sequence.unsqueeze(1) == values_to_search).sum(dim=1)
+
+    bin_edges = torch.linspace(min, max, max - min + 1,
+                               dtype=input.dtype).to(input.device)
+    return searchsorted(bin_edges, input).to(torch.int32)
 
 
 def fused_moe(
@@ -61,7 +77,7 @@ def fused_moe(
     x = torch.ops.xla.gmm(x, w2, group_sizes)
     x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)
 
-    x = x * topk_weights.unsqueeze_(dim=-1)
+    x = x * topk_weights.unsqueeze(dim=-1)
     x = x.sum(dim=-2)
     x = x.reshape(orig_shape)
     return x
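For reference, a minimal, runnable sketch of what the two code changes do. The `_histogram` body is taken from the patch itself; the routing values, tensor shapes, and the `torch.bincount` cross-check are made up here for illustration and are not part of the PR:

```python
import torch


def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor:
    """Unit-width histogram over [min, max], as inlined in moe_pallas.py."""
    assert input.dtype == torch.int32, "input must be of torch.int32 dtype."
    assert min <= max, "min must be less than or equal to max."

    def searchsorted(sorted_sequence: torch.Tensor,
                     values_to_search: torch.Tensor) -> torch.Tensor:
        # For each bin edge, count how many input values are equal to it.
        return (sorted_sequence.unsqueeze(1) == values_to_search).sum(dim=1)

    bin_edges = torch.linspace(min, max, max - min + 1,
                               dtype=input.dtype).to(input.device)
    return searchsorted(bin_edges, input).to(torch.int32)


# Hypothetical routing result: 6 token-expert assignments over 4 experts.
topk_ids = torch.tensor([0, 2, 2, 3, 0, 2], dtype=torch.int32)
group_sizes = _histogram(topk_ids, 0, 3)
print(group_sizes)  # tensor([2, 0, 3, 1], dtype=torch.int32)

# Sanity check: torch.bincount computes the same per-expert counts.
assert torch.equal(group_sizes,
                   torch.bincount(topk_ids, minlength=4).to(torch.int32))

# The unsqueeze_ -> unsqueeze fix: the in-place variant would reshape
# topk_weights from [tokens, topk] to [tokens, topk, 1] as a side effect,
# mutating the caller's tensor; the out-of-place call leaves it untouched.
topk_weights = torch.rand(6, 2)
expanded = topk_weights.unsqueeze(dim=-1)
assert topk_weights.shape == (6, 2) and expanded.shape == (6, 2, 1)
```

In the kernel itself, this histogram of expert IDs is what feeds the `group_sizes` argument of the `torch.ops.xla.gmm` call visible in the last hunk; inlining it drops the dependency on the private `torch_xla.experimental.custom_kernel._histogram` helper.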