[TPU]Fix KV cache sharing tests (#19371)

This commit is contained in:
Siyuan Liu 2025-06-09 15:38:15 -07:00 committed by GitHub
parent 31f58be96a
commit 7d44c469fe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import unittest.mock as mock
import pytest
@ -17,24 +16,8 @@ from vllm.v1.worker.tpu_model_runner import (
TPUModelRunner, _get_padded_num_reqs_with_upper_limit,
_get_padded_token_len, _get_req_paddings, _get_token_paddings)
# Mock torch_xla module since it may not be available in the test environments
torch_xla_patcher = mock.patch.dict(
"sys.modules", {
"torch_xla": mock.MagicMock(),
"torch_xla.core.xla_model": mock.MagicMock(),
"torch_xla.runtime": mock.MagicMock(),
})
torch_xla_patcher.start()
# Mock the PallasAttentionBackend
pallas_attention_backend_patcher = mock.patch(
"vllm.v1.worker.tpu_model_runner.PallasAttentionBackend", )
pallas_attention_backend_patcher.start()
@pytest.fixture
def model_runner():
# Patchers have already been started at module level.
def get_vllm_config():
scheduler_config = SchedulerConfig(
max_num_seqs=10,
max_num_batched_tokens=512,
@ -60,18 +43,19 @@ def model_runner():
cache_config=cache_config,
scheduler_config=scheduler_config,
)
return vllm_config
def get_model_runner(vllm_config):
device = "xla:0" # Mocking TPU device
with mock.patch("vllm.v1.worker.tpu_model_runner.torch"), \
mock.patch("vllm.v1.worker.tpu_model_runner.xm"), \
mock.patch("vllm.v1.worker.tpu_model_runner.xr"):
return TPUModelRunner(vllm_config, device)
return TPUModelRunner(vllm_config, device)
@pytest.fixture(autouse=True, scope="session")
def cleanup_patches():
yield
torch_xla_patcher.stop()
pallas_attention_backend_patcher.stop()
@pytest.fixture
def model_runner():
# Patchers have already been started at module level.
vllm_config = get_vllm_config()
return get_model_runner(vllm_config)
def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
@ -370,12 +354,14 @@ def test_get_req_paddings():
assert _get_req_paddings(8, 36) == [8, 16, 32, 36]
@pytest.mark.skip(reason="Test is broken on TPU when it's added.")
def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(
model_runner):
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
error_msg = f"{layer_1} must come before the current layer"
with pytest.raises(ValueError, match=error_msg):
vllm_config = model_runner.vllm_config
with pytest.raises(ValueError, match=error_msg), \
set_current_vllm_config(vllm_config):
fwd_context = {
# initialization below will fail because target layer is invalid;
# the target layer needs to come before layer 1
@ -399,13 +385,14 @@ def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
assert fwd_context is not None
@pytest.mark.skip(reason="Test is broken on TPU when it's added.")
def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
def test_init_kv_cache_with_kv_sharing_target_layer_not_exist(model_runner):
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
invalid_layer = "model.layers.0.cross_attn.attn"
error_msg = f"{invalid_layer} is not a valid Attention layer in the model"
with pytest.raises(ValueError, match=error_msg):
vllm_config = model_runner.vllm_config
with pytest.raises(ValueError, match=error_msg), \
set_current_vllm_config(vllm_config):
fwd_context = {
layer_0:
Attention(
@ -428,12 +415,13 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
assert fwd_context is not None
@pytest.mark.skip(reason="Test is broken on TPU when it's added.")
def test_init_kv_cache_with_kv_sharing_target_same_as_current():
def test_init_kv_cache_with_kv_sharing_target_same_as_current(model_runner):
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
error_msg = f"{layer_1} cannot be the same as the current layer"
with pytest.raises(ValueError, match=error_msg):
vllm_config = model_runner.vllm_config
with pytest.raises(ValueError, match=error_msg), \
set_current_vllm_config(vllm_config):
fwd_context = {
# initialization below will fail because target layer is invalid;
# the target layer needs to come before layer 1
@ -457,11 +445,10 @@ def test_init_kv_cache_with_kv_sharing_target_same_as_current():
assert fwd_context is not None
@pytest.mark.skip(reason="Test is broken on TPU when it's added.")
def test_init_kv_cache_without_kv_sharing(model_runner):
def test_init_kv_cache_without_kv_sharing():
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
vllm_config = model_runner.vllm_config
vllm_config = get_vllm_config()
with set_current_vllm_config(vllm_config):
fwd_context = {
layer_0:
@ -482,33 +469,38 @@ def test_init_kv_cache_without_kv_sharing(model_runner):
# suppress var not used error
assert fwd_context is not None
# Set high context length to test max context length estimation
vllm_config.model_config.max_model_len = 3_000_000
vllm_config.model_config.max_model_len = 1_000_000
vllm_ctx = vllm_config.compilation_config.static_forward_context
model_runner = get_model_runner(vllm_config)
kv_cache_spec = model_runner.get_kv_cache_spec()
assert len(kv_cache_spec) == 2
assert len(model_runner.shared_kv_cache_layers) == 0
available_memory = 20 * GiB_bytes
# page size for layer 0's kv_cache_spec is 32KB
num_expected_blocks = 327680 # 20GB / 32KB / 2 (num layers)
# page size for each layer KV can be calculated as
# 2 (non-MLA) * 8 (num_heads) * 128 (head_dim)
# * 2 (bfloat16, kv_cache dtype) * 128 (block_size) = 512KB
num_expected_blocks = 20480 # 20GB / 512KB / 2 (num layers)
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
available_memory)
assert kv_cache_config.num_blocks == num_expected_blocks
assert len(kv_cache_config.tensors) == 2
assert kv_cache_config.tensors[layer_0].size == available_memory // 2
assert kv_cache_config.tensors[layer_1].size == available_memory // 2
assert len(kv_cache_config.kv_cache_tensors) == 2
assert kv_cache_config.kv_cache_tensors[0].size == available_memory // 2
assert kv_cache_config.kv_cache_tensors[1].size == available_memory // 2
max_context_len =\
estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
# max context len with KV sharing should be 2x as large as without
assert max_context_len == 1310720
# max_context_len = available_memory / (page_size / block_size) / num_caches
# max_context_len = 5GB / (512KB / 128) / 2 = 655360
assert max_context_len == 655360
# important: override tensor size to prevent large mem alloc during test
# this will only allocate 2 block worth of memory (2 * 32kb)
# this will only allocate 2 block worth of memory (2 * 512kb)
kv_cache_config.num_blocks = 1
for layer in kv_cache_config.tensors:
kv_cache_config.tensors[layer].size =\
kv_cache_spec[layer].page_size_bytes
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
kv_cache_tensor.size = (
kv_cache_spec[kv_cache_tensor.shared_by[0]].page_size_bytes)
model_runner.initialize_kv_cache(kv_cache_config)
@ -524,11 +516,10 @@ def test_init_kv_cache_without_kv_sharing(model_runner):
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
@pytest.mark.skip(reason="Test is broken on TPU when it's added.")
def test_init_kv_cache_with_kv_sharing_valid(model_runner):
def test_init_kv_cache_with_kv_sharing_valid():
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
vllm_config = model_runner.vllm_config
vllm_config = get_vllm_config()
with set_current_vllm_config(vllm_config):
fwd_context = {
layer_0:
@ -552,33 +543,34 @@ def test_init_kv_cache_with_kv_sharing_valid(model_runner):
# Set high context length to test max context length estimation
vllm_config.model_config.max_model_len = 3_000_000
vllm_ctx = vllm_config.compilation_config.static_forward_context
model_runner = get_model_runner(vllm_config)
kv_cache_spec = model_runner.get_kv_cache_spec()
assert len(kv_cache_spec) == 1
assert layer_0 in kv_cache_spec
assert model_runner.shared_kv_cache_layers[layer_1] == layer_0
available_memory = 20 * GiB_bytes
# page size for layer 0's kv_cache_spec is 32KB
# page size for layer 0's kv_cache_spec is 512KB
# with KV sharing, we can allocate (available_mem//page_size//1) blocks
# which is twice as many as without KV sharing
num_expected_blocks = 655360 # 20GB / 32KB
num_expected_blocks = 2 * 20480 # 20GB / 512KB
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
available_memory)
assert kv_cache_config.num_blocks == num_expected_blocks
assert len(kv_cache_config.tensors) == 1
assert len(kv_cache_config.kv_cache_tensors) == 1
# Each layer now has twice the available memory for KV cache
# compared to no KV sharing
assert kv_cache_config.tensors[layer_0].size == available_memory
assert kv_cache_config.kv_cache_tensors[0].size == available_memory
max_context_len =\
estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
# max context len with KV sharing should be 2x as large as without
assert max_context_len == 2 * 1310720
assert max_context_len == (2 * 655360)
# important: override tensor size to prevent large mem alloc during test
# this will only allocate 1 block worth of memory (32kb)
# this will only allocate 1 block worth of memory (512kb)
kv_cache_config.num_blocks = 1
kv_cache_config.tensors[layer_0].size =\
kv_cache_config.kv_cache_tensors[0].size =\
kv_cache_spec[layer_0].page_size_bytes
model_runner.initialize_kv_cache(kv_cache_config)