[FEAT] [AITER] [ROCm] integrate aiter sampling ops (#26084)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Parent: da8dadf68b
Commit: 0af3d4f0df
@@ -7,6 +7,7 @@ import torch.nn as nn
from packaging import version

from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config.model import LogprobsMode
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
@@ -55,6 +56,17 @@ class TopKTopPSampler(nn.Module):
                self.forward = self.forward_native
            else:
                self.forward = self.forward_cpu
        elif (
            logprobs_mode not in ("processed_logits", "processed_logprobs")
            and rocm_aiter_ops.is_enabled()
        ):
            import aiter.ops.sampling  # noqa: F401

            self.aiter_ops = torch.ops.aiter
            logger.info_once(
                "Using aiter sampler on ROCm (lazy import, sampling-only)."
            )
            self.forward = self.forward_hip
        else:
            self.forward = self.forward_native
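Note (not part of the diff): the hunk above keeps vLLM's existing pattern of choosing the sampling backend once in the constructor by binding self.forward to a concrete bound method, so the per-call hot path carries no platform branching. A minimal, self-contained sketch of that dispatch pattern with hypothetical names (_DispatchExample, use_fast_backend):

import torch
import torch.nn as nn


class _DispatchExample(nn.Module):
    """Toy module: pick the forward implementation once at construction."""

    def __init__(self, use_fast_backend: bool) -> None:
        super().__init__()
        # Bind the chosen backend as an instance attribute; nn.Module.__call__
        # looks up self.forward dynamically, so this overrides per instance.
        self.forward = self._fast if use_fast_backend else self._native

    def _fast(self, logits: torch.Tensor) -> torch.Tensor:
        # Stand-in for an accelerated kernel (greedy pick).
        return logits.argmax(dim=-1)

    def _native(self, logits: torch.Tensor) -> torch.Tensor:
        # Plain PyTorch fallback: sample from the softmax distribution.
        return torch.multinomial(logits.softmax(dim=-1), num_samples=1).view(-1)


sampler = _DispatchExample(use_fast_backend=False)
print(sampler(torch.randn(2, 8)))  # tensor of 2 sampled token ids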
@@ -138,6 +150,64 @@ class TopKTopPSampler(nn.Module):

        return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return

    def forward_hip(
        self,
        logits: torch.Tensor,
        generators: dict[int, torch.Generator],
        k: torch.Tensor | None,
        p: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Optimized ROCm/aiter path (same structure as forward_cuda)."""
        if (k is None and p is None) or generators:
            if generators:
                logger.warning_once(
                    "aiter sampler does not support per-request generators; "
                    "falling back to PyTorch-native."
                )
            return self.forward_native(logits, generators, k, p)
        assert self.logprobs_mode not in (
            "processed_logits",
            "processed_logprobs",
        ), "aiter sampler does not support returning logits/logprobs."
        return self.aiter_sample(logits, k, p, generators), None

    def aiter_sample(
        self,
        logits: torch.Tensor,
        k: torch.Tensor | None,
        p: torch.Tensor | None,
        generators: dict[int, torch.Generator],
    ) -> torch.Tensor:
        """Sample from logits using aiter ops."""
        use_top_k = k is not None
        use_top_p = p is not None
        # Joint k+p path
        if use_top_p and use_top_k:
            probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous()
            next_token_ids = self.aiter_ops.top_k_top_p_sampling_from_probs(
                probs,
                None,
                *_to_tensor_scalar_tuple(k),
                *_to_tensor_scalar_tuple(p),
                deterministic=True,
            )
            return next_token_ids.view(-1)
        # Top-p only path
        elif use_top_p:
            probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous()
            next_token_ids = self.aiter_ops.top_p_sampling_from_probs(
                probs, None, *_to_tensor_scalar_tuple(p), deterministic=True
            )
            return next_token_ids.view(-1)
        # Top-k only path
        elif use_top_k:
            probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous()
            renorm_probs = self.aiter_ops.top_k_renorm_probs(
                probs, *_to_tensor_scalar_tuple(k)
            )
            return torch.multinomial(renorm_probs, num_samples=1).view(-1)
        raise RuntimeError("aiter_sample was called with no active top-k or top-p.")


# Note: this is a workaround for
# https://github.com/pytorch/pytorch/pull/151218
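Note (not part of the diff): for readers unfamiliar with the fused aiter kernel, below is a plain-PyTorch reference with roughly the same sampling semantics as the joint top-k + top-p path above: mask to the top-k tokens, then to the top-p nucleus, renormalize, and sample. The function name is illustrative, tensors are assumed to be shaped [batch, vocab], and every request is assumed to have k >= 1 and 0 < p <= 1; the aiter op fuses this work into a single kernel operating on probabilities.

import torch


def reference_top_k_top_p_sample(
    logits: torch.Tensor, k: torch.Tensor, p: torch.Tensor
) -> torch.Tensor:
    """Unfused reference: mask to top-k, then to the top-p nucleus, and sample."""
    probs = logits.softmax(dim=-1, dtype=torch.float32)
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    # Top-k: keep only the k highest-probability tokens of each row.
    ranks = torch.arange(probs.shape[-1], device=probs.device)
    topk_mask = ranks.unsqueeze(0) < k.unsqueeze(-1)
    # Top-p: keep the smallest prefix whose cumulative mass reaches p
    # (the token that crosses the threshold is kept).
    cumulative = sorted_probs.cumsum(dim=-1)
    topp_mask = (cumulative - sorted_probs) < p.unsqueeze(-1)
    # Scatter the combined mask back to vocabulary order.
    keep_sorted = (topk_mask & topp_mask).to(probs.dtype)
    keep = torch.zeros_like(probs).scatter(-1, sorted_idx, keep_sorted).bool()
    # Renormalize the surviving probabilities and sample one token per row.
    masked = torch.where(keep, probs, torch.zeros_like(probs))
    masked = masked / masked.sum(dim=-1, keepdim=True)
    return torch.multinomial(masked, num_samples=1).view(-1)


logits = torch.randn(2, 32)
token_ids = reference_top_k_top_p_sample(
    logits, k=torch.tensor([8, 4]), p=torch.tensor([0.95, 0.9])
)
print(token_ids)  # two sampled token ids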
@@ -288,3 +358,10 @@ def flashinfer_sample(
    )

    return next_token_ids.view(-1)


def _to_tensor_scalar_tuple(x):
    if isinstance(x, torch.Tensor):
        return (x, 0)
    else:
        return (None, x)
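Note (not part of the diff): _to_tensor_scalar_tuple lets the same aiter op call site accept either per-request values (a tensor plus a dummy scalar of 0) or one batch-wide value (None plus the scalar). A quick self-contained illustration of the two cases:

import torch


def _to_tensor_scalar_tuple(x):
    if isinstance(x, torch.Tensor):
        return (x, 0)
    else:
        return (None, x)


k_per_request = torch.tensor([40, 20, 1])  # different top-k per request
p_uniform = 0.95                           # one top-p shared by the batch

print(_to_tensor_scalar_tuple(k_per_request))  # (tensor([40, 20,  1]), 0)
print(_to_tensor_scalar_tuple(p_uniform))      # (None, 0.95)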