mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 01:05:01 +08:00
[V1] Use FlashInfer by default on Blackwell GPUs (#19118)
This commit is contained in:
parent
aa49f14832
commit
87360308b7
@ -229,6 +229,21 @@ class CudaPlatformBase(Platform):
|
|||||||
logger.info_once("Using Triton backend on V1 engine.")
|
logger.info_once("Using Triton backend on V1 engine.")
|
||||||
return ("vllm.v1.attention.backends."
|
return ("vllm.v1.attention.backends."
|
||||||
"triton_attn.TritonAttentionBackend")
|
"triton_attn.TritonAttentionBackend")
|
||||||
|
if cls.is_device_capability(100):
|
||||||
|
# Prefer FlashInfer for V1 on Blackwell GPUs if installed
|
||||||
|
try:
|
||||||
|
import flashinfer # noqa: F401
|
||||||
|
logger.info_once(
|
||||||
|
"Using FlashInfer backend on V1 engine by default for "
|
||||||
|
"Blackwell (SM 10.0) GPUs.")
|
||||||
|
return ("vllm.v1.attention.backends."
|
||||||
|
"flashinfer.FlashInferBackend")
|
||||||
|
except ImportError:
|
||||||
|
logger.info_once(
|
||||||
|
"FlashInfer failed to import for V1 engine on "
|
||||||
|
"Blackwell (SM 10.0) GPUs; it is recommended to "
|
||||||
|
"install FlashInfer for better performance.")
|
||||||
|
pass
|
||||||
if cls.has_device_capability(80):
|
if cls.has_device_capability(80):
|
||||||
logger.info_once("Using Flash Attention backend on V1 engine.")
|
logger.info_once("Using Flash Attention backend on V1 engine.")
|
||||||
return ("vllm.v1.attention.backends."
|
return ("vllm.v1.attention.backends."
|
||||||
|
|||||||
@ -228,6 +228,30 @@ class Platform:
|
|||||||
|
|
||||||
return current_capability.to_int() >= capability
|
return current_capability.to_int() >= capability
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_device_capability(
|
||||||
|
cls,
|
||||||
|
capability: Union[tuple[int, int], int],
|
||||||
|
device_id: int = 0,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Test whether this platform has exactly the specified device capability.
|
||||||
|
|
||||||
|
The `capability` argument can either be:
|
||||||
|
|
||||||
|
- A tuple `(major, minor)`.
|
||||||
|
- An integer `<major><minor>`. (See
|
||||||
|
[`DeviceCapability.to_int`][vllm.platforms.interface.DeviceCapability.to_int])
|
||||||
|
"""
|
||||||
|
current_capability = cls.get_device_capability(device_id=device_id)
|
||||||
|
if current_capability is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if isinstance(capability, tuple):
|
||||||
|
return current_capability == capability
|
||||||
|
|
||||||
|
return current_capability.to_int() == capability
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_device_name(cls, device_id: int = 0) -> str:
|
def get_device_name(cls, device_id: int = 0) -> str:
|
||||||
"""Get the name of a device."""
|
"""Get the name of a device."""
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user