[FEAT] [AITER] [ROCm] integrate aiter sampling ops (#26084)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
vllmellm 2025-11-19 01:28:34 +08:00 committed by GitHub
parent da8dadf68b
commit 0af3d4f0df

@@ -7,6 +7,7 @@ import torch.nn as nn
from packaging import version
from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config.model import LogprobsMode
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
@@ -55,6 +56,17 @@ class TopKTopPSampler(nn.Module):
self.forward = self.forward_native
else:
self.forward = self.forward_cpu
elif (
logprobs_mode not in ("processed_logits", "processed_logprobs")
and rocm_aiter_ops.is_enabled()
):
import aiter.ops.sampling # noqa: F401
self.aiter_ops = torch.ops.aiter
logger.info_once(
"Using aiter sampler on ROCm (lazy import, sampling-only)."
)
self.forward = self.forward_hip
else:
self.forward = self.forward_native
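For context, the new branch keeps aiter sampling behind two conditions: ROCm aiter ops must be enabled, and the configured logprobs mode must not require returning processed logits/logprobs (which the fused sampling kernel cannot provide). Below is a condensed, hypothetical restatement of that decision; the helper name and the plain boolean flags are illustrative only, while the real code uses rocm_aiter_ops.is_enabled() and the platform branches shown above.

def pick_forward_impl(on_rocm: bool, aiter_enabled: bool, logprobs_mode: str) -> str:
    # Mirrors the added elif: skip aiter whenever processed logits/logprobs are requested.
    if (
        on_rocm
        and aiter_enabled
        and logprobs_mode not in ("processed_logits", "processed_logprobs")
    ):
        return "forward_hip"   # aiter-backed sampling path
    return "forward_native"    # pure-PyTorch fallback

assert pick_forward_impl(True, True, "raw_logprobs") == "forward_hip"
assert pick_forward_impl(True, True, "processed_logprobs") == "forward_native"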
@@ -138,6 +150,64 @@ class TopKTopPSampler(nn.Module):
return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return
def forward_hip(
self,
logits: torch.Tensor,
generators: dict[int, torch.Generator],
k: torch.Tensor | None,
p: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""Optimized ROCm/aiter path (same structure as forward_cuda)."""
if (k is None and p is None) or generators:
if generators:
logger.warning_once(
"aiter sampler does not support per-request generators; "
"falling back to PyTorch-native."
)
return self.forward_native(logits, generators, k, p)
assert self.logprobs_mode not in (
"processed_logits",
"processed_logprobs",
), "aiter sampler does not support returning logits/logprobs."
return self.aiter_sample(logits, k, p, generators), None
def aiter_sample(
self,
logits: torch.Tensor,
k: torch.Tensor | None,
p: torch.Tensor | None,
generators: dict[int, torch.Generator],
) -> torch.Tensor:
"""Sample from logits using aiter ops."""
use_top_k = k is not None
use_top_p = p is not None
# Joint k+p path
if use_top_p and use_top_k:
probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous()
next_token_ids = self.aiter_ops.top_k_top_p_sampling_from_probs(
probs,
None,
*_to_tensor_scalar_tuple(k),
*_to_tensor_scalar_tuple(p),
deterministic=True,
)
return next_token_ids.view(-1)
# Top-p only path
elif use_top_p:
probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous()
next_token_ids = self.aiter_ops.top_p_sampling_from_probs(
probs, None, *_to_tensor_scalar_tuple(p), deterministic=True
)
return next_token_ids.view(-1)
# Top-k only path
elif use_top_k:
probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous()
renorm_probs = self.aiter_ops.top_k_renorm_probs(
probs, *_to_tensor_scalar_tuple(k)
)
return torch.multinomial(renorm_probs, num_samples=1).view(-1)
raise RuntimeError("aiter_sample was called with no active top-k or top-p.")
# Note: this is a workaround for
# https://github.com/pytorch/pytorch/pull/151218
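For readers unfamiliar with the fused kernel, the sketch below is a rough, unfused pure-PyTorch reference for the joint top-k + top-p path, using a scalar k and p for simplicity; the diff itself calls aiter's fused top_k_top_p_sampling_from_probs, which accepts per-request tensors and performs the filtering and sampling in a single op.

import torch

def topk_topp_sample_reference(logits: torch.Tensor, k: int, p: float) -> torch.Tensor:
    # Top-k: mask out everything below the k-th largest logit in each row.
    kth_largest = logits.topk(k, dim=-1).values[..., -1:]
    logits = logits.masked_fill(logits < kth_largest, float("-inf"))
    # Top-p (nucleus): keep the smallest prefix of the sorted distribution whose
    # cumulative probability reaches p, always retaining at least the top token.
    probs = logits.softmax(dim=-1, dtype=torch.float32)
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cumulative = sorted_probs.cumsum(dim=-1)
    drop = (cumulative - sorted_probs) > p
    sorted_probs = sorted_probs.masked_fill(drop, 0.0)
    probs = torch.zeros_like(probs).scatter_(-1, sorted_idx, sorted_probs)
    probs = probs / probs.sum(dim=-1, keepdim=True)
    # Draw one token id per row from the renormalized distribution.
    return torch.multinomial(probs, num_samples=1).view(-1)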
@@ -288,3 +358,10 @@ def flashinfer_sample(
)
return next_token_ids.view(-1)
def _to_tensor_scalar_tuple(x):
if isinstance(x, torch.Tensor):
return (x, 0)
else:
return (None, x)
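The helper above packs either a per-request tensor or a single batch-wide scalar into the (tensor, scalar) pair that the aiter sampling ops are called with. A small usage sketch follows; the values are illustrative, and the exact argument names on the kernel side are not shown in this diff.

import torch

k = torch.tensor([40, 20, 10])       # per-request top-k -> (tensor, 0)
print(_to_tensor_scalar_tuple(k))    # (tensor([40, 20, 10]), 0)
print(_to_tensor_scalar_tuple(0.9))  # batch-wide top-p   -> (None, 0.9)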