mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-02 03:11:23 +08:00
[FEAT] [AITER] [ROCm] integrate aiter sampling ops (#26084)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
parent
da8dadf68b
commit
0af3d4f0df
@ -7,6 +7,7 @@ import torch.nn as nn
|
|||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
|
from vllm._aiter_ops import rocm_aiter_ops
|
||||||
from vllm.config.model import LogprobsMode
|
from vllm.config.model import LogprobsMode
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import CpuArchEnum, current_platform
|
from vllm.platforms import CpuArchEnum, current_platform
|
||||||
@ -55,6 +56,17 @@ class TopKTopPSampler(nn.Module):
|
|||||||
self.forward = self.forward_native
|
self.forward = self.forward_native
|
||||||
else:
|
else:
|
||||||
self.forward = self.forward_cpu
|
self.forward = self.forward_cpu
|
||||||
|
elif (
|
||||||
|
logprobs_mode not in ("processed_logits", "processed_logprobs")
|
||||||
|
and rocm_aiter_ops.is_enabled()
|
||||||
|
):
|
||||||
|
import aiter.ops.sampling # noqa: F401
|
||||||
|
|
||||||
|
self.aiter_ops = torch.ops.aiter
|
||||||
|
logger.info_once(
|
||||||
|
"Using aiter sampler on ROCm (lazy import, sampling-only)."
|
||||||
|
)
|
||||||
|
self.forward = self.forward_hip
|
||||||
else:
|
else:
|
||||||
self.forward = self.forward_native
|
self.forward = self.forward_native
|
||||||
|
|
||||||
@ -138,6 +150,64 @@ class TopKTopPSampler(nn.Module):
|
|||||||
|
|
||||||
return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return
|
return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return
|
||||||
|
|
||||||
|
def forward_hip(
|
||||||
|
self,
|
||||||
|
logits: torch.Tensor,
|
||||||
|
generators: dict[int, torch.Generator],
|
||||||
|
k: torch.Tensor | None,
|
||||||
|
p: torch.Tensor | None,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||||
|
"""Optimized ROCm/aiter path (same structure as forward_cuda)."""
|
||||||
|
if (k is None and p is None) or generators:
|
||||||
|
if generators:
|
||||||
|
logger.warning_once(
|
||||||
|
"aiter sampler does not support per-request generators; "
|
||||||
|
"falling back to PyTorch-native."
|
||||||
|
)
|
||||||
|
return self.forward_native(logits, generators, k, p)
|
||||||
|
assert self.logprobs_mode not in (
|
||||||
|
"processed_logits",
|
||||||
|
"processed_logprobs",
|
||||||
|
), "aiter sampler does not support returning logits/logprobs."
|
||||||
|
return self.aiter_sample(logits, k, p, generators), None
|
||||||
|
|
||||||
|
def aiter_sample(
|
||||||
|
self,
|
||||||
|
logits: torch.Tensor,
|
||||||
|
k: torch.Tensor | None,
|
||||||
|
p: torch.Tensor | None,
|
||||||
|
generators: dict[int, torch.Generator],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Sample from logits using aiter ops."""
|
||||||
|
use_top_k = k is not None
|
||||||
|
use_top_p = p is not None
|
||||||
|
# Joint k+p path
|
||||||
|
if use_top_p and use_top_k:
|
||||||
|
probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous()
|
||||||
|
next_token_ids = self.aiter_ops.top_k_top_p_sampling_from_probs(
|
||||||
|
probs,
|
||||||
|
None,
|
||||||
|
*_to_tensor_scalar_tuple(k),
|
||||||
|
*_to_tensor_scalar_tuple(p),
|
||||||
|
deterministic=True,
|
||||||
|
)
|
||||||
|
return next_token_ids.view(-1)
|
||||||
|
# Top-p only path
|
||||||
|
elif use_top_p:
|
||||||
|
probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous()
|
||||||
|
next_token_ids = self.aiter_ops.top_p_sampling_from_probs(
|
||||||
|
probs, None, *_to_tensor_scalar_tuple(p), deterministic=True
|
||||||
|
)
|
||||||
|
return next_token_ids.view(-1)
|
||||||
|
# Top-k only path
|
||||||
|
elif use_top_k:
|
||||||
|
probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous()
|
||||||
|
renorm_probs = self.aiter_ops.top_k_renorm_probs(
|
||||||
|
probs, *_to_tensor_scalar_tuple(k)
|
||||||
|
)
|
||||||
|
return torch.multinomial(renorm_probs, num_samples=1).view(-1)
|
||||||
|
raise RuntimeError("aiter_sample was called with no active top-k or top-p.")
|
||||||
|
|
||||||
|
|
||||||
# Note: this is a workaround for
# https://github.com/pytorch/pytorch/pull/151218
@ -288,3 +358,10 @@ def flashinfer_sample(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return next_token_ids.view(-1)
|
return next_token_ids.view(-1)
|
||||||
|
|
||||||
|
|
||||||
|
def _to_tensor_scalar_tuple(x):
|
||||||
|
if isinstance(x, torch.Tensor):
|
||||||
|
return (x, 0)
|
||||||
|
else:
|
||||||
|
return (None, x)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user