[TPU] Increase block size and reset block shapes (#16458)

parent 6115b11582
commit 621ca2c0ab
@@ -22,7 +22,8 @@ def main():
     # In real workloads, `enforace_eager` should be `False`.
     llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
               max_num_batched_tokens=64,
-              max_num_seqs=4)
+              max_num_seqs=4,
+              max_model_len=128)
     outputs = llm.generate(prompts, sampling_params)
     print("-" * 50)
     for output, answer in zip(outputs, answers):
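A minimal sketch, not part of the commit: with the page-size heuristic this commit adds to PallasAttentionBackend (see the pallas.py hunk below), pinning max_model_len=128 in the example keeps the TPU block size at the lower clamp of 16, since 128 split into 16 pages gives 8 and the result is rounded up to the floor. The inline mirror of the heuristic is illustrative only.

max_model_len = 128
# mirror of the divide-by-16 heuristic plus its lower clamp (illustrative only)
page_size = max((1 << (max_model_len - 1).bit_length()) // 16, 16)
assert page_size == 16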
@@ -18,9 +18,9 @@ setuptools==78.1.0
 --find-links https://storage.googleapis.com/libtpu-releases/index.html
 --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
 --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-torch==2.8.0.dev20250408
-torchvision==0.22.0.dev20250408
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch==2.8.0.dev20250430
+torchvision==0.22.0.dev20250430
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
 
@@ -76,9 +76,9 @@ class TpuPlatform(Platform):
         from vllm.config import CompilationLevel
 
         cache_config = vllm_config.cache_config
+        # For v0, the default block size is 16.
         if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16
 
         compilation_config = vllm_config.compilation_config
 
         # TPU only supports DYNAMO_ONCE compilation level
@@ -101,16 +101,18 @@ class TpuPlatform(Platform):
         if envs.VLLM_USE_V1:
             from vllm.v1.attention.backends.pallas import (
                 PallasAttentionBackend)
+            cache_config.block_size = PallasAttentionBackend.get_page_size(
+                vllm_config)
             min_page_size = PallasAttentionBackend.get_min_page_size(
                 vllm_config)
-            if min_page_size > vllm_config.cache_config.block_size:
+            if min_page_size > cache_config.block_size:
                 logger.warning(
                     "Increase the page size from %s to %s to make sure there's"
                     "no SMEM OOM",
-                    vllm_config.cache_config.block_size,
+                    cache_config.block_size,
                     min_page_size,
                 )
-                vllm_config.cache_config.block_size = min_page_size
+                cache_config.block_size = min_page_size
 
         parallel_config = vllm_config.parallel_config
         scheduler_config = vllm_config.scheduler_config
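A minimal sketch, not from the commit, of the V1 block-size selection the hunk above implements: the block size starts from get_page_size and is raised to get_min_page_size only when the SMEM-driven minimum is larger. _select_block_size is a hypothetical pure-function stand-in; the real code mutates cache_config.block_size in place.

def _select_block_size(page_size: int, min_page_size: int) -> int:
    block_size = page_size              # heuristic derived from max_model_len
    if min_page_size > block_size:      # SMEM lower bound wins when stricter
        block_size = min_page_size
    return block_size

assert _select_block_size(page_size=16, min_page_size=32) == 32    # warning path
assert _select_block_size(page_size=128, min_page_size=32) == 128  # heuristic kept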
@@ -707,6 +707,13 @@ def cdiv(a: int, b: int) -> int:
     return -(a // -b)
 
 
+def next_power_of_2(n) -> int:
+    """The next power of 2 (inclusive)"""
+    if n < 1:
+        return 1
+    return 1 << (n - 1).bit_length()
+
+
 def round_up(x: int, y: int) -> int:
     return ((x + y - 1) // y) * y
 
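Usage sketch for the new vllm.utils.next_power_of_2 helper (hand-checked values); "inclusive" means an exact power of two maps to itself:

from vllm.utils import next_power_of_2

assert next_power_of_2(0) == 1    # n < 1 falls back to 1
assert next_power_of_2(1) == 1
assert next_power_of_2(5) == 8
assert next_power_of_2(8) == 8    # inclusive: exact powers of two are unchanged
assert next_power_of_2(9) == 16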
@@ -12,7 +12,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.utils import cdiv
+from vllm.utils import cdiv, next_power_of_2
 
 logger = init_logger(__name__)
 
@@ -65,6 +65,20 @@ class PallasAttentionBackend(AttentionBackend):
         min_page_size = 1 << (min_page_size - 1).bit_length()
         return min_page_size
 
+    # TPU has limited SREGs (scalar registers), if page_size is too small, we
+    # can spill SREGs easily which leads to bad performance. The strategy we
+    # apply here is trying to split max-model-len to 16 pages which make the
+    # spill less likely. Meanwhile we make sure the page size is in [16, 256].
+    @staticmethod
+    def get_page_size(vllm_config: VllmConfig) -> int:
+        page_size = next_power_of_2(
+            vllm_config.model_config.max_model_len) // 16
+        if page_size <= 16:
+            return 16
+        if page_size >= 256:
+            return 256
+        return page_size
+
 
 @dataclass
 class PallasMetadata:
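Hand-worked values for the new heuristic, assuming only the arithmetic shown in get_page_size above; calling it directly would require building a full VllmConfig, so the [16, 256] clamp is repeated inline here:

from vllm.utils import next_power_of_2

for max_model_len, expected in [(128, 16), (2048, 128), (3000, 256), (32768, 256)]:
    page_size = next_power_of_2(max_model_len) // 16
    page_size = min(max(page_size, 16), 256)   # same clamp as get_page_size
    assert page_size == expected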