mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 18:45:21 +08:00
Guard FlashInfer sampler using the same check as FlashInfer attention backend (#29415)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
parent
0808eb813b
commit
9eec282cb5
@ -33,6 +33,16 @@ class TopKTopPSampler(nn.Module):
|
|||||||
and current_platform.is_cuda()
|
and current_platform.is_cuda()
|
||||||
):
|
):
|
||||||
if envs.VLLM_USE_FLASHINFER_SAMPLER:
|
if envs.VLLM_USE_FLASHINFER_SAMPLER:
|
||||||
|
from vllm.v1.attention.backends.flashinfer import FlashInferBackend
|
||||||
|
|
||||||
|
capability = current_platform.get_device_capability()
|
||||||
|
assert capability is not None
|
||||||
|
if not FlashInferBackend.supports_compute_capability(capability):
|
||||||
|
capability_str = capability.as_version_str()
|
||||||
|
raise RuntimeError(
|
||||||
|
"FlashInfer does not support compute capability "
|
||||||
|
f"{capability_str}, unset VLLM_USE_FLASHINFER_SAMPLER=1."
|
||||||
|
)
|
||||||
# Users must opt in explicitly via VLLM_USE_FLASHINFER_SAMPLER=1.
|
# Users must opt in explicitly via VLLM_USE_FLASHINFER_SAMPLER=1.
|
||||||
logger.info_once(
|
logger.info_once(
|
||||||
"Using FlashInfer for top-p & top-k sampling.",
|
"Using FlashInfer for top-p & top-k sampling.",
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user