mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 21:55:32 +08:00
[TPU][V1] Implicitly adjust page size when there's SMEM OOM (#16871)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
This commit is contained in:
parent
3a0fba5cf4
commit
471fe65630
@ -22,6 +22,7 @@ MODELS = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
TENSOR_PARALLEL_SIZES = [1]
|
TENSOR_PARALLEL_SIZES = [1]
|
||||||
|
MAX_NUM_REQS = [16, 1024]
|
||||||
|
|
||||||
# TODO: Enable when CI/CD will have a multi-tpu instance
|
# TODO: Enable when CI/CD will have a multi-tpu instance
|
||||||
# TENSOR_PARALLEL_SIZES = [1, 4]
|
# TENSOR_PARALLEL_SIZES = [1, 4]
|
||||||
@ -32,12 +33,14 @@ TENSOR_PARALLEL_SIZES = [1]
|
|||||||
@pytest.mark.parametrize("model", MODELS)
|
@pytest.mark.parametrize("model", MODELS)
|
||||||
@pytest.mark.parametrize("max_tokens", [5])
|
@pytest.mark.parametrize("max_tokens", [5])
|
||||||
@pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES)
|
@pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES)
|
||||||
|
@pytest.mark.parametrize("max_num_seqs", MAX_NUM_REQS)
|
||||||
def test_basic(
|
def test_basic(
|
||||||
vllm_runner: type[VllmRunner],
|
vllm_runner: type[VllmRunner],
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
model: str,
|
model: str,
|
||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
tensor_parallel_size: int,
|
tensor_parallel_size: int,
|
||||||
|
max_num_seqs: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
prompt = "The next numbers of the sequence " + ", ".join(
|
prompt = "The next numbers of the sequence " + ", ".join(
|
||||||
str(i) for i in range(1024)) + " are:"
|
str(i) for i in range(1024)) + " are:"
|
||||||
@ -51,9 +54,9 @@ def test_basic(
|
|||||||
# Note: max_num_batched_tokens == 1024 is needed here to
|
# Note: max_num_batched_tokens == 1024 is needed here to
|
||||||
# actually test chunked prompt
|
# actually test chunked prompt
|
||||||
max_num_batched_tokens=1024,
|
max_num_batched_tokens=1024,
|
||||||
max_model_len=8196,
|
max_model_len=8192,
|
||||||
gpu_memory_utilization=0.7,
|
gpu_memory_utilization=0.7,
|
||||||
max_num_seqs=16,
|
max_num_seqs=max_num_seqs,
|
||||||
tensor_parallel_size=tensor_parallel_size) as vllm_model:
|
tensor_parallel_size=tensor_parallel_size) as vllm_model:
|
||||||
vllm_outputs = vllm_model.generate_greedy(example_prompts,
|
vllm_outputs = vllm_model.generate_greedy(example_prompts,
|
||||||
max_tokens)
|
max_tokens)
|
||||||
|
|||||||
@ -97,6 +97,20 @@ class TpuPlatform(Platform):
|
|||||||
"Using bfloat16 instead.", vllm_config.model_config.dtype)
|
"Using bfloat16 instead.", vllm_config.model_config.dtype)
|
||||||
vllm_config.model_config.dtype = torch.bfloat16
|
vllm_config.model_config.dtype = torch.bfloat16
|
||||||
|
|
||||||
|
if envs.VLLM_USE_V1:
|
||||||
|
from vllm.v1.attention.backends.pallas import (
|
||||||
|
PallasAttentionBackend)
|
||||||
|
min_page_size = PallasAttentionBackend.get_min_page_size(
|
||||||
|
vllm_config)
|
||||||
|
if min_page_size > vllm_config.cache_config.block_size:
|
||||||
|
logger.warning(
|
||||||
|
"Increase the page size from %s to %s to make sure there's"
|
||||||
|
"no SMEM OOM",
|
||||||
|
vllm_config.cache_config.block_size,
|
||||||
|
min_page_size,
|
||||||
|
)
|
||||||
|
vllm_config.cache_config.block_size = min_page_size
|
||||||
|
|
||||||
parallel_config = vllm_config.parallel_config
|
parallel_config = vllm_config.parallel_config
|
||||||
scheduler_config = vllm_config.scheduler_config
|
scheduler_config = vllm_config.scheduler_config
|
||||||
if parallel_config.worker_cls == "auto":
|
if parallel_config.worker_cls == "auto":
|
||||||
|
|||||||
@ -10,7 +10,9 @@ import torch_xla.experimental.custom_kernel # noqa: F401
|
|||||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||||
AttentionLayer, AttentionType)
|
AttentionLayer, AttentionType)
|
||||||
from vllm.attention.backends.utils import CommonAttentionState
|
from vllm.attention.backends.utils import CommonAttentionState
|
||||||
|
from vllm.config import VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.utils import cdiv
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -50,6 +52,19 @@ class PallasAttentionBackend(AttentionBackend):
|
|||||||
) -> None:
|
) -> None:
|
||||||
raise RuntimeError("swap_blocks is not used for the TPU backend.")
|
raise RuntimeError("swap_blocks is not used for the TPU backend.")
|
||||||
|
|
||||||
|
# In recent TPU generations, up to v6e, the SMEM size is 1MB. The
|
||||||
|
# block_tables within the PallasMetadata constitute almost the entire SMEM
|
||||||
|
# requirement. Its size is max_num_seqs * num_page_per_seq * 4 (Int). Here
|
||||||
|
# we simply make sure that the size is smaller than half of SMEM capacity.
|
||||||
|
@staticmethod
|
||||||
|
def get_min_page_size(vllm_config: VllmConfig) -> int:
|
||||||
|
max_num_page_per_req = (1024 * 1024 // 2 //
|
||||||
|
vllm_config.scheduler_config.max_num_seqs // 4)
|
||||||
|
min_page_size = cdiv(vllm_config.model_config.max_model_len,
|
||||||
|
max_num_page_per_req)
|
||||||
|
min_page_size = 1 << (min_page_size - 1).bit_length()
|
||||||
|
return min_page_size
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PallasMetadata:
|
class PallasMetadata:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user