diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py
index 02ea658b7f20e..c6c7e924175f7 100644
--- a/vllm/v1/sample/ops/topk_topp_sampler.py
+++ b/vllm/v1/sample/ops/topk_topp_sampler.py
@@ -7,6 +7,7 @@ import torch.nn as nn
 from packaging import version
 
 from vllm import envs
+from vllm._aiter_ops import rocm_aiter_ops
 from vllm.config.model import LogprobsMode
 from vllm.logger import init_logger
 from vllm.platforms import CpuArchEnum, current_platform
@@ -55,6 +56,17 @@ class TopKTopPSampler(nn.Module):
                 self.forward = self.forward_native
             else:
                 self.forward = self.forward_cpu
+        elif (
+            logprobs_mode not in ("processed_logits", "processed_logprobs")
+            and rocm_aiter_ops.is_enabled()
+        ):
+            import aiter.ops.sampling  # noqa: F401
+
+            self.aiter_ops = torch.ops.aiter
+            logger.info_once(
+                "Using aiter sampler on ROCm (lazy import, sampling-only)."
+            )
+            self.forward = self.forward_hip
         else:
             self.forward = self.forward_native
 
@@ -138,6 +150,64 @@ class TopKTopPSampler(nn.Module):
 
         return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return
 
+    def forward_hip(
+        self,
+        logits: torch.Tensor,
+        generators: dict[int, torch.Generator],
+        k: torch.Tensor | None,
+        p: torch.Tensor | None,
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        """Optimized ROCm/aiter path (same structure as forward_cuda)."""
+        if (k is None and p is None) or generators:
+            if generators:
+                logger.warning_once(
+                    "aiter sampler does not support per-request generators; "
+                    "falling back to PyTorch-native."
+                )
+            return self.forward_native(logits, generators, k, p)
+        assert self.logprobs_mode not in (
+            "processed_logits",
+            "processed_logprobs",
+        ), "aiter sampler does not support returning logits/logprobs."
+        return self.aiter_sample(logits, k, p, generators), None
+
+    def aiter_sample(
+        self,
+        logits: torch.Tensor,
+        k: torch.Tensor | None,
+        p: torch.Tensor | None,
+        generators: dict[int, torch.Generator],
+    ) -> torch.Tensor:
+        """Sample from logits using aiter ops."""
+        use_top_k = k is not None
+        use_top_p = p is not None
+        # Joint k+p path
+        if use_top_p and use_top_k:
+            probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous()
+            next_token_ids = self.aiter_ops.top_k_top_p_sampling_from_probs(
+                probs,
+                None,
+                *_to_tensor_scalar_tuple(k),
+                *_to_tensor_scalar_tuple(p),
+                deterministic=True,
+            )
+            return next_token_ids.view(-1)
+        # Top-p only path
+        elif use_top_p:
+            probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous()
+            next_token_ids = self.aiter_ops.top_p_sampling_from_probs(
+                probs, None, *_to_tensor_scalar_tuple(p), deterministic=True
+            )
+            return next_token_ids.view(-1)
+        # Top-k only path
+        elif use_top_k:
+            probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous()
+            renorm_probs = self.aiter_ops.top_k_renorm_probs(
+                probs, *_to_tensor_scalar_tuple(k)
+            )
+            return torch.multinomial(renorm_probs, num_samples=1).view(-1)
+        raise RuntimeError("aiter_sample was called with no active top-k or top-p.")
+
 
 # Note: this is a workaround for
 # https://github.com/pytorch/pytorch/pull/151218
@@ -288,3 +358,10 @@ def flashinfer_sample(
     )
 
     return next_token_ids.view(-1)
+
+
+def _to_tensor_scalar_tuple(x):
+    if isinstance(x, torch.Tensor):
+        return (x, 0)
+    else:
+        return (None, x)
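
For reviewers, the standalone snippet below (not part of the diff) illustrates the two conventions the new code relies on: how _to_tensor_scalar_tuple packs a per-request tensor versus a single scalar for the aiter ops, and what the top-k-only branch computes, using a pure-PyTorch stand-in for top_k_renorm_probs. The shapes and the top_k_only_reference helper are illustrative assumptions, not vLLM or aiter APIs.

# Reviewer sketch: pure-PyTorch illustration of the argument packing and the
# top-k-only sampling path added in this diff. No aiter import required.
import torch


def _to_tensor_scalar_tuple(x):
    # Same helper as in the diff: a per-request tensor is passed through with a
    # scalar placeholder of 0, while a plain Python scalar is passed as
    # (None, scalar) so a single value applies to all requests.
    if isinstance(x, torch.Tensor):
        return (x, 0)
    return (None, x)


def top_k_only_reference(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # Assumed-equivalent reference for the top-k-only branch in aiter_sample:
    # softmax, keep each row's k largest probabilities, renormalize, then sample
    # with torch.multinomial (aiter's top_k_renorm_probs is assumed to perform
    # the keep-and-renormalize step).
    probs = logits.softmax(dim=-1, dtype=torch.float32)
    kth = probs.topk(int(k.max()), dim=-1).values.gather(-1, (k - 1).unsqueeze(-1))
    renorm = torch.where(probs >= kth, probs, torch.zeros_like(probs))
    renorm = renorm / renorm.sum(dim=-1, keepdim=True)
    return torch.multinomial(renorm, num_samples=1).view(-1)


logits = torch.randn(2, 8)           # [num_requests, vocab_size]; shapes are illustrative
k = torch.tensor([3, 5])             # per-request top-k
print(_to_tensor_scalar_tuple(k))    # (tensor([3, 5]), 0)
print(_to_tensor_scalar_tuple(0.9))  # (None, 0.9)
print(top_k_only_reference(logits, k))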