mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-13 23:34:32 +08:00
[bugfix][DCP] fix block_size of hash in DCP prefix caching (#26296)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
parent
720d3cd0f0
commit
606b00e80f
@ -1411,6 +1411,7 @@ def create_scheduler_with_priority(
|
|||||||
kv_cache_config=kv_cache_config,
|
kv_cache_config=kv_cache_config,
|
||||||
log_stats=True,
|
log_stats=True,
|
||||||
structured_output_manager=StructuredOutputManager(vllm_config),
|
structured_output_manager=StructuredOutputManager(vllm_config),
|
||||||
|
block_size=block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -129,6 +129,7 @@ def create_scheduler(
|
|||||||
return scheduler_cls(
|
return scheduler_cls(
|
||||||
vllm_config=vllm_config,
|
vllm_config=vllm_config,
|
||||||
kv_cache_config=kv_cache_config,
|
kv_cache_config=kv_cache_config,
|
||||||
|
block_size=block_size,
|
||||||
log_stats=True,
|
log_stats=True,
|
||||||
structured_output_manager=StructuredOutputManager(vllm_config),
|
structured_output_manager=StructuredOutputManager(vllm_config),
|
||||||
)
|
)
|
||||||
|
|||||||
@ -138,6 +138,7 @@ def create_scheduler(
|
|||||||
kv_cache_config=kv_cache_config,
|
kv_cache_config=kv_cache_config,
|
||||||
log_stats=True,
|
log_stats=True,
|
||||||
structured_output_manager=StructuredOutputManager(vllm_config),
|
structured_output_manager=StructuredOutputManager(vllm_config),
|
||||||
|
block_size=block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -45,6 +45,7 @@ class Scheduler(SchedulerInterface):
|
|||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
kv_cache_config: KVCacheConfig,
|
kv_cache_config: KVCacheConfig,
|
||||||
structured_output_manager: StructuredOutputManager,
|
structured_output_manager: StructuredOutputManager,
|
||||||
|
block_size: int,
|
||||||
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
||||||
include_finished_set: bool = False,
|
include_finished_set: bool = False,
|
||||||
log_stats: bool = False,
|
log_stats: bool = False,
|
||||||
@ -101,15 +102,8 @@ class Scheduler(SchedulerInterface):
|
|||||||
num_gpu_blocks = self.cache_config.num_gpu_blocks
|
num_gpu_blocks = self.cache_config.num_gpu_blocks
|
||||||
assert num_gpu_blocks is not None and num_gpu_blocks > 0
|
assert num_gpu_blocks is not None and num_gpu_blocks > 0
|
||||||
|
|
||||||
self.block_size = self.cache_config.block_size
|
self.block_size = block_size
|
||||||
|
|
||||||
self.dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
|
self.dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
|
||||||
# Note(hc): The scheduler’s block_size must be multiplied
|
|
||||||
# by dcp_world_size, since block hashes are computed on the
|
|
||||||
# original full token sequence at a granularity of
|
|
||||||
# original_block_size × dcp_world_size.
|
|
||||||
if self.dcp_world_size > 1:
|
|
||||||
self.block_size *= self.dcp_world_size
|
|
||||||
|
|
||||||
# req_id -> Request
|
# req_id -> Request
|
||||||
self.requests: dict[str, Request] = {}
|
self.requests: dict[str, Request] = {}
|
||||||
|
|||||||
@ -142,12 +142,18 @@ class EngineCore:
|
|||||||
logger.info("Disabling chunked prefill for model without KVCache")
|
logger.info("Disabling chunked prefill for model without KVCache")
|
||||||
vllm_config.scheduler_config.chunked_prefill_enabled = False
|
vllm_config.scheduler_config.chunked_prefill_enabled = False
|
||||||
|
|
||||||
|
scheduler_block_size = (
|
||||||
|
vllm_config.cache_config.block_size
|
||||||
|
* vllm_config.parallel_config.decode_context_parallel_size
|
||||||
|
)
|
||||||
|
|
||||||
self.scheduler: SchedulerInterface = Scheduler(
|
self.scheduler: SchedulerInterface = Scheduler(
|
||||||
vllm_config=vllm_config,
|
vllm_config=vllm_config,
|
||||||
kv_cache_config=kv_cache_config,
|
kv_cache_config=kv_cache_config,
|
||||||
structured_output_manager=self.structured_output_manager,
|
structured_output_manager=self.structured_output_manager,
|
||||||
include_finished_set=vllm_config.parallel_config.data_parallel_size > 1,
|
include_finished_set=vllm_config.parallel_config.data_parallel_size > 1,
|
||||||
log_stats=self.log_stats,
|
log_stats=self.log_stats,
|
||||||
|
block_size=scheduler_block_size,
|
||||||
)
|
)
|
||||||
self.use_spec_decode = vllm_config.speculative_config is not None
|
self.use_spec_decode = vllm_config.speculative_config is not None
|
||||||
if self.scheduler.connector is not None: # type: ignore
|
if self.scheduler.connector is not None: # type: ignore
|
||||||
@ -177,14 +183,13 @@ class EngineCore:
|
|||||||
self.vllm_config.cache_config.enable_prefix_caching
|
self.vllm_config.cache_config.enable_prefix_caching
|
||||||
or self.scheduler.get_kv_connector() is not None
|
or self.scheduler.get_kv_connector() is not None
|
||||||
):
|
):
|
||||||
block_size = vllm_config.cache_config.block_size
|
|
||||||
caching_hash_fn = get_hash_fn_by_name(
|
caching_hash_fn = get_hash_fn_by_name(
|
||||||
vllm_config.cache_config.prefix_caching_hash_algo
|
vllm_config.cache_config.prefix_caching_hash_algo
|
||||||
)
|
)
|
||||||
init_none_hash(caching_hash_fn)
|
init_none_hash(caching_hash_fn)
|
||||||
|
|
||||||
self.request_block_hasher = get_request_block_hasher(
|
self.request_block_hasher = get_request_block_hasher(
|
||||||
block_size, caching_hash_fn
|
scheduler_block_size, caching_hash_fn
|
||||||
)
|
)
|
||||||
|
|
||||||
self.step_fn = (
|
self.step_fn = (
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user