[Feature] Allow configuring FlashInfer workspace size (#28269)
Signed-off-by: Max Hu <hyoung2991@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
commit 412e153df5
parent e5f599d4d1
@@ -159,6 +159,7 @@ if TYPE_CHECKING:
     VLLM_USE_FLASHINFER_MOE_FP8: bool = False
     VLLM_USE_FLASHINFER_MOE_FP4: bool = False
     VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency"] = "latency"
+    VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE: int = 394 * 1024 * 1024
     VLLM_XGRAMMAR_CACHE_MB: int = 0
     VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
     VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
@@ -1237,6 +1238,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_FLASHINFER_MOE_BACKEND": env_with_choices(
         "VLLM_FLASHINFER_MOE_BACKEND", "latency", ["throughput", "latency"]
     ),
+    # Control the workspace buffer size for the FlashInfer backend.
+    "VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE": lambda: int(
+        os.getenv("VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE", str(394 * 1024 * 1024))
+    ),
     # Control the maximum number of tokens per expert supported by the
     # NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for
     # the blockscale tensor of activations NVFP4 Quantization.
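The new entry follows the pattern of the other integer-valued knobs in vLLM's envs module: a typed declaration under TYPE_CHECKING for static analysis, plus a lazily evaluated getter that parses the environment variable and falls back to 394 MiB. The standalone sketch below is illustrative only, not the envs module itself; the helper name and _DEFAULT constant are made up for the example.

import os

# Illustrative default, matching the value registered in the diff above:
# 394 * 1024 * 1024 bytes (394 MiB).
_DEFAULT = str(394 * 1024 * 1024)

def flashinfer_workspace_buffer_size() -> int:
    # Same shape as the registered lambda: read the variable if set,
    # otherwise fall back to the default, then parse it as a byte count.
    return int(os.getenv("VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE", _DEFAULT))

For example, exporting VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE=536870912 before starting vLLM would enlarge the workspace to 512 MiB.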
@@ -1583,6 +1588,7 @@ def compute_hash() -> str:
         "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8",
         "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS",
         "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16",
+        "VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE",
         "VLLM_USE_CUDNN_PREFILL",
         "VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL",
         "VLLM_USE_TRTLLM_ATTENTION",
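Listing the variable in compute_hash() means the configured size feeds into vLLM's environment hash, so changing it yields a different digest instead of silently reusing state computed under another setting. A rough, hedged illustration of the idea, not vLLM's actual implementation:

import hashlib
import os

def env_hash(names: list[str]) -> str:
    # Hash the current values of a fixed set of environment variables so
    # that any change to one of them produces a different digest.
    h = hashlib.sha256()
    for name in sorted(names):
        h.update(f"{name}={os.getenv(name, '')}".encode())
    return h.hexdigest()

# "VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE" now participates in this kind of
# digest, alongside the other FlashInfer-related knobs listed above.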
@@ -16,6 +16,7 @@ from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
 from flashinfer.prefill import trtllm_batch_context_with_kv_cache
 from flashinfer.utils import FP4Tensor
 
+from vllm import envs
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionImpl,
@@ -55,7 +56,6 @@ from vllm.v1.attention.backends.utils import (
 )
 from vllm.v1.kv_cache_interface import AttentionSpec
 
-FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
 FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024
 
 FP8_DTYPE = current_platform.fp8_dtype()
@@ -70,7 +70,7 @@ def _get_trtllm_gen_workspace_buffer():
     global trtllm_gen_workspace_buffer
     if trtllm_gen_workspace_buffer is None:
         trtllm_gen_workspace_buffer = torch.zeros(
-            FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device="cuda"
+            envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device="cuda"
         )
     return trtllm_gen_workspace_buffer
 
@@ -414,7 +414,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
 
     def _get_workspace_buffer(self):
        if self._workspace_buffer is None:
-            buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE
+            buffer_size = envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE
            if vllm_is_batch_invariant():
                buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
            self._workspace_buffer = torch.zeros(
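Both allocation sites in the FlashInfer backend now size their lazily created workspace from the environment variable, while the batch-invariant path keeps its larger fixed buffer. The self-contained sketch below illustrates that pattern; WorkspaceHolder, the batch_invariant flag, and the CPU default device are assumptions for the example, not vLLM's actual classes (vLLM allocates on the CUDA device).

from __future__ import annotations

import os

import torch

BATCH_INVARIANT_BYTES = 2048 * 1024 * 1024  # fixed size kept by the diff above


class WorkspaceHolder:
    """Hypothetical stand-in for the metadata builder's buffer handling."""

    def __init__(self, device: str = "cpu") -> None:
        self._device = device
        self._workspace_buffer: torch.Tensor | None = None

    def get(self, batch_invariant: bool = False) -> torch.Tensor:
        # Allocate once, on first use, sized from the env var (with the same
        # 394 MiB fallback), unless batch-invariant mode pins the size.
        if self._workspace_buffer is None:
            size = int(os.getenv("VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE",
                                 str(394 * 1024 * 1024)))
            if batch_invariant:
                size = BATCH_INVARIANT_BYTES
            self._workspace_buffer = torch.zeros(
                size, dtype=torch.uint8, device=self._device
            )
        return self._workspace_buffer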
@@ -196,8 +196,8 @@ from typing import ClassVar, Generic, TypeVar
 import torch
 from tqdm import tqdm
 
-import vllm.envs as envs
 from vllm import _custom_ops as ops
+from vllm import envs
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.attention.backends.abstract import (
     AttentionBackend,
@@ -453,12 +453,6 @@ def use_trtllm_ragged_deepseek_prefill() -> bool:
     )
 
 
-# Currently 394MB, this can be tuned based on GEMM sizes used.
-# Chosen to be the same as sglang:
-# https://github.com/sgl-project/sglang/blob/766392c6bda2558b61ce6d1c1bfd8081a549e1f1/python/sglang/global_config.py#L37
-FLASHINFER_WORKSPACE_BUFFER_SIZE = 394 * 1024 * 1024
-
-
 class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
     """
     NOTE: Please read the comment at the top of the file before trying to
@@ -590,7 +584,9 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
 
         if self._use_fi_prefill:
             self._workspace_buffer = torch.empty(
-                FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=device
+                envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE,
+                dtype=torch.uint8,
+                device=device,
             )
 
         self._fi_prefill_main: BatchPrefillWithRaggedKVCacheWrapper | None = None
@@ -602,7 +598,9 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
 
         if self._use_trtllm_ragged_prefill:
             self._workspace_buffer = torch.empty(
-                FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=device
+                envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE,
+                dtype=torch.uint8,
+                device=device,
             )
 
         if self._use_cudnn_prefill:
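With the module-level constants gone, the MLA prefill paths also pick up the configured size, so one environment variable controls every FlashInfer workspace allocation and the sglang-derived 394 MiB default applies when it is unset. A hedged usage sketch; the model name is a placeholder and any vLLM entry point would behave the same:

import os

# Enlarge the FlashInfer workspace to 512 MiB before vLLM is imported, so the
# value is in place when the attention backends allocate their buffers.
os.environ["VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE"] = str(512 * 1024 * 1024)

from vllm import LLM

llm = LLM(model="facebook/opt-125m")  # placeholder model for illustration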