[BugFix] Fix --disable-log-stats in V1 server mode (#17600)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill 2025-05-07 21:08:15 -07:00 committed by GitHub
parent 66ab3b13c9
commit 3d13ca0e24
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 5 deletions

View File

@ -120,8 +120,9 @@ class AsyncLLM(EngineClient):
executor_class=executor_class, executor_class=executor_class,
log_stats=self.log_stats, log_stats=self.log_stats,
) )
for stat_logger in self.stat_loggers[0]: if self.stat_loggers:
stat_logger.log_engine_initialized() for stat_logger in self.stat_loggers[0]:
stat_logger.log_engine_initialized()
self.output_handler: Optional[asyncio.Task] = None self.output_handler: Optional[asyncio.Task] = None
try: try:
# Start output handler eagerly if we are in the asyncio eventloop. # Start output handler eagerly if we are in the asyncio eventloop.

View File

@ -442,9 +442,10 @@ class MPClient(EngineCoreClient):
logger.info("Core engine process %d ready.", eng_id) logger.info("Core engine process %d ready.", eng_id)
identities.discard(eng_id) identities.discard(eng_id)
# Setup KV cache config with initialization state from # Setup KV cache config with initialization state from
# engine core process. # engine core process. Sum values from all engines in DP case.
self.vllm_config.cache_config.num_gpu_blocks = message_dict[ num_gpu_blocks = self.vllm_config.cache_config.num_gpu_blocks or 0
'num_gpu_blocks'] num_gpu_blocks += message_dict['num_gpu_blocks']
self.vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
def _init_core_engines( def _init_core_engines(
self, self,