mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 13:55:38 +08:00
[Bugfix] Respect num-gpu-blocks-override in v1 (#19503)
Signed-off-by: Jon Swenson <jmswen@gmail.com>
This commit is contained in:
parent
af09b3f0a0
commit
c9280e6346
@ -900,3 +900,19 @@ def test_get_kv_cache_config():
|
|||||||
with pytest.raises(NotImplementedError):
|
with pytest.raises(NotImplementedError):
|
||||||
get_kv_cache_config(vllm_config, kv_cache_specs_hybrid,
|
get_kv_cache_config(vllm_config, kv_cache_specs_hybrid,
|
||||||
mem_per_block_per_layer * 2 * 32)
|
mem_per_block_per_layer * 2 * 32)
|
||||||
|
|
||||||
|
# Test num_gpu_blocks_override
|
||||||
|
vllm_config.cache_config.num_gpu_blocks_override = 16
|
||||||
|
kv_cache_config_override_blocks = get_kv_cache_config(
|
||||||
|
vllm_config, kv_cache_specs_full, mem_per_block_per_layer * 2 * 32)
|
||||||
|
assert kv_cache_config_override_blocks == KVCacheConfig(
|
||||||
|
num_blocks=16,
|
||||||
|
kv_cache_tensors=[
|
||||||
|
KVCacheTensor(size=mem_per_block_per_layer * 16,
|
||||||
|
shared_by=["layer_1"]),
|
||||||
|
KVCacheTensor(size=mem_per_block_per_layer * 16,
|
||||||
|
shared_by=["layer_2"]),
|
||||||
|
],
|
||||||
|
kv_cache_groups=[
|
||||||
|
KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec())
|
||||||
|
])
|
||||||
@ -660,6 +660,7 @@ def get_num_blocks(vllm_config: VllmConfig, num_layers: int,
|
|||||||
logger.info(
|
logger.info(
|
||||||
"Overriding num_gpu_blocks=%d with "
|
"Overriding num_gpu_blocks=%d with "
|
||||||
"num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override)
|
"num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override)
|
||||||
|
num_blocks = num_gpu_blocks_override
|
||||||
return num_blocks
|
return num_blocks
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user