[V1] Allocate kv_cache with stride order for V1 (#18775)
Signed-off-by: nicklucche <nlucches@redhat.com>
commit 32ce3cf7c9
parent d58f9c7f7a
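For context: the heart of this change is that the KV cache buffer is now allocated in whatever physical axis order the attention backend prefers, then permuted back so the rest of the runner still sees the canonical (2, num_blocks, block_size, num_kv_heads, head_size) shape. A minimal sketch of that trick, with illustrative shapes rather than vLLM's exact code:

import random
import torch

# Canonical logical shape vs. a backend-preferred physical order (both
# illustrative). The identity order reproduces the previous behaviour.
logical_shape = (2, 10, 16, 8, 128)
stride_order = tuple(random.sample(range(5), 5))

# Allocate storage in the permuted (physical) order...
buf = torch.zeros(tuple(logical_shape[i] for i in stride_order))
# ...and restore the logical view with the inverse permutation.
inv_order = [stride_order.index(i) for i in range(len(stride_order))]
kv_cache = buf.permute(*inv_order)

assert tuple(kv_cache.shape) == logical_shape  # shape is unchanged
# Only the memory layout (hence contiguity) can differ.
assert kv_cache.is_contiguous() == (stride_order == (0, 1, 2, 3, 4))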
tests/v1/worker/test_gpu_model_runner.py

@@ -1,7 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import random
+
 import pytest
 
+from vllm.attention import Attention
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig, VllmConfig)
 from vllm.sampling_params import SamplingParams
@@ -13,27 +16,30 @@ from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.worker.gpu_input_batch import InputBatch
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
+BLOCK_SIZE = 16
+NUM_BLOCKS = 10
+
 
 def initialize_kv_cache(runner: GPUModelRunner):
     """
     Only perform necessary steps in GPUModelRunner.initialize_kv_cache()
     """
+    attn_spec = FullAttentionSpec(
+        block_size=BLOCK_SIZE,
+        num_kv_heads=runner.model_config.get_num_kv_heads(
+            runner.parallel_config),
+        head_size=runner.model_config.get_head_size(),
+        dtype=runner.kv_cache_dtype,
+        use_mla=False,
+    )
+    tensor_size = attn_spec.page_size_bytes * NUM_BLOCKS
     kv_cache_config = KVCacheConfig(
-        num_blocks=10,
+        num_blocks=NUM_BLOCKS,
         tensors={
-            "layer.0": KVCacheTensor(size=1024),
+            "layer.0": KVCacheTensor(size=tensor_size),
         },
         kv_cache_groups=[
-            KVCacheGroupSpec(
-                layer_names=["layer.0"],
-                kv_cache_spec=FullAttentionSpec(
-                    block_size=16,
-                    num_kv_heads=runner.model_config.get_num_kv_heads(
-                        runner.parallel_config),
-                    head_size=runner.model_config.get_head_size(),
-                    dtype=runner.kv_cache_dtype,
-                    use_mla=False,
-                ))
+            KVCacheGroupSpec(layer_names=["layer.0"], kv_cache_spec=attn_spec)
         ])
     runner.kv_cache_config = kv_cache_config
     runner.input_batch = InputBatch(
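Note on the sizing change above: the helper now derives each layer's KVCacheTensor size from the spec's page_size_bytes instead of a hard-coded 1024, so the tensor really holds NUM_BLOCKS blocks. For a full-attention layer a page (one block) stores both a key block and a value block, so the number works out roughly as in this sketch (the head count and dtype are example values, and this is not vLLM's own page_size_bytes implementation):

import torch

def page_size_bytes(block_size: int, num_kv_heads: int, head_size: int,
                    dtype: torch.dtype) -> int:
    # One page holds a key block and a value block for `block_size` tokens,
    # hence the factor of 2.
    dtype_size = torch.empty((), dtype=dtype).element_size()
    return 2 * block_size * num_kv_heads * head_size * dtype_size

# With BLOCK_SIZE=16 and, say, 8 KV heads of size 128 in fp16:
# 2 * 16 * 8 * 128 * 2 = 65,536 bytes per block, so NUM_BLOCKS=10 needs 655,360.
print(page_size_bytes(16, 8, 128, torch.float16) * 10)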
@@ -65,7 +71,7 @@ def model_runner():
         seed=42,
     )
     cache_config = CacheConfig(
-        block_size=16,
+        block_size=BLOCK_SIZE,
         gpu_memory_utilization=0.9,
         swap_space=0,
         cache_dtype="auto",
@@ -77,6 +83,10 @@ def model_runner():
         scheduler_config=scheduler_config,
         parallel_config=parallel_config,
     )
+    num_heads = model_config.get_num_kv_heads(parallel_config)
+    head_size = model_config.get_head_size()
+    vllm_config.compilation_config.static_forward_context[
+        "layer.0"] = Attention(num_heads, head_size, 0.1)
 
     device = "cuda"
     runner = GPUModelRunner(vllm_config, device)
@@ -321,3 +331,38 @@ def test_update_states_request_unscheduled(model_runner):
 
     assert _is_req_added(model_runner, req_ids[1])
     assert not _is_req_scheduled(model_runner, req_ids[1])
+
+
+def test_kv_cache_stride_order(monkeypatch, model_runner):
+    # This test checks if GPUModelRunner initializes correctly when an attention
+    # backend enforces a non-default KV cache stride order.
+    n_heads = model_runner.model_config.get_num_kv_heads(
+        model_runner.parallel_config)
+    expected_kv_cache_shape = [
+        2, NUM_BLOCKS, BLOCK_SIZE, n_heads,
+        model_runner.model_config.get_head_size()
+    ]
+    # TODO mla test
+    default_stride = list(range(5))
+    # Permutation that gets you back to expected kv shape
+    rnd_stride = tuple(random.sample(default_stride, len(default_stride)))
+
+    def rnd_stride_order():
+        return rnd_stride
+
+    # Patch the attention backend class and re-trigger the KV cache creation.
+    for attn_backend in model_runner.attn_backends:
+        monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order",
+                            rnd_stride_order)
+
+    model_runner.attn_backends = []
+    model_runner.attn_metadata_builders = []
+    model_runner.initialize_kv_cache(model_runner.kv_cache_config)
+
+    # Shape is unchanged, but layout may differ
+    kv_cache_shape = model_runner.kv_caches[0].shape
+    assert list(kv_cache_shape) == expected_kv_cache_shape
+    if default_stride == rnd_stride:
+        assert all(kv.is_contiguous() for kv in model_runner.kv_caches)
+    else:
+        assert all(not kv.is_contiguous() for kv in model_runner.kv_caches)
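The test above monkeypatches get_kv_cache_stride_order onto whichever backend the runner selected; when the hook is absent, the runner falls back to the identity order, so existing backends keep their current layout. A backend that wants a different physical layout would expose something along these lines (the method names and the canonical shape come from the diff; the class and its chosen order are hypothetical):

class BlocksFirstBackend:  # hypothetical backend, not part of vLLM

    @staticmethod
    def get_kv_cache_shape(num_blocks: int, block_size: int,
                           num_kv_heads: int,
                           head_size: int) -> tuple[int, ...]:
        # Canonical logical shape the model runner works with.
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
        # Axes of the canonical shape, listed in preferred physical order:
        # here (num_blocks, 2, block_size, num_kv_heads, head_size).
        return (1, 0, 2, 3, 4)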
vllm/v1/worker/gpu_model_runner.py

@@ -2033,9 +2033,29 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                         num_blocks, kv_cache_spec.block_size,
                         kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
                     dtype = kv_cache_spec.dtype
-                    kv_caches[layer_name] = torch.zeros(kv_cache_shape,
-                                                        dtype=dtype,
-                                                        device=self.device)
+                    try:
+                        kv_cache_stride_order = self.attn_backends[
+                            i].get_kv_cache_stride_order()
+                        assert len(kv_cache_stride_order) == len(
+                            kv_cache_shape)
+                    except (AttributeError, NotImplementedError):
+                        kv_cache_stride_order = tuple(
+                            range(len(kv_cache_shape)))
+                    # The allocation respects the backend-defined stride order
+                    # to ensure the semantic remains consistent for each
+                    # backend. We first obtain the generic kv cache shape and
+                    # then permute it according to the stride order which could
+                    # result in a non-contiguous tensor.
+                    kv_cache_shape = tuple(kv_cache_shape[i]
+                                           for i in kv_cache_stride_order)
+                    # Maintain original KV shape view.
+                    inv_order = [
+                        kv_cache_stride_order.index(i)
+                        for i in range(len(kv_cache_stride_order))
+                    ]
+                    kv_caches[layer_name] = torch.zeros(
+                        kv_cache_shape, dtype=dtype,
+                        device=self.device).permute(*inv_order)
                 else:
                     # TODO: add new branches when introducing more types of
                     # KV cache specs.
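The permuted allocation keeps the canonical shape while its strides follow the backend's physical order, which is exactly what the new test checks through is_contiguous(). A quick way to see the effect (shapes and order are illustrative):

import torch

logical_shape = (2, 10, 16, 8, 128)
stride_order = (1, 0, 2, 3, 4)  # e.g. a blocks-first physical layout
inv_order = [stride_order.index(i) for i in range(len(stride_order))]

buf = torch.zeros(tuple(logical_shape[i] for i in stride_order))
view = buf.permute(*inv_order)

print(tuple(view.shape))     # (2, 10, 16, 8, 128) -> canonical shape preserved
print(view.stride())         # (16384, 32768, 1024, 128, 1) -> block axis varies slowest
print(view.is_contiguous())  # False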