[CI/Build] Fix multimodal tests (#22491)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-08-08 15:31:19 +08:00 committed by GitHub
parent 808a7b69df
commit 1712543df6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 17 additions and 15 deletions

View File

@@ -845,7 +845,8 @@ class LLMEngine:
def reset_mm_cache(self) -> bool: def reset_mm_cache(self) -> bool:
"""Reset the multi-modal cache.""" """Reset the multi-modal cache."""
return self.input_preprocessor.mm_registry.reset_processor_cache() return self.input_preprocessor.mm_registry.reset_processor_cache(
self.model_config)
def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
"""Reset prefix cache for all devices.""" """Reset prefix cache for all devices."""

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping from collections.abc import Mapping
from dataclasses import dataclass from dataclasses import dataclass
from functools import lru_cache
from typing import TYPE_CHECKING, Generic, Optional, Protocol, TypeVar from typing import TYPE_CHECKING, Generic, Optional, Protocol, TypeVar
import torch.nn as nn import torch.nn as nn
@@ -86,6 +87,13 @@ class _ProcessorFactories(Generic[_I]):
return self.processor(info, dummy_inputs_builder, cache=cache) return self.processor(info, dummy_inputs_builder, cache=cache)
# Make sure a different cache is used for each model config
# NOTE: ModelConfig is not hashable so it cannot be passed directly
@lru_cache(maxsize=1)
def _get_processor_cache(model_id: str, capacity_gb: int):
return ProcessingCache(capacity_gb) if capacity_gb > 0 else None
class MultiModalRegistry: class MultiModalRegistry:
""" """
A registry that dispatches data processing according to the model. A registry that dispatches data processing according to the model.
@@ -95,22 +103,15 @@ class MultiModalRegistry:
self._processor_factories = ClassRegistry[nn.Module, self._processor_factories = ClassRegistry[nn.Module,
_ProcessorFactories]() _ProcessorFactories]()
self._processor_cache: Optional[ProcessingCache] = None
def _get_processor_cache(self, model_config: "ModelConfig"): def _get_processor_cache(self, model_config: "ModelConfig"):
model_id = model_config.model
capacity_gb = model_config.mm_processor_cache_gb capacity_gb = model_config.mm_processor_cache_gb
if capacity_gb is None: return _get_processor_cache(model_id, capacity_gb)
return None # Overrides `disable_cache` argument
if self._processor_cache is None: def reset_processor_cache(self, model_config: "ModelConfig") -> bool:
self._processor_cache = ProcessingCache(capacity_gb)
return self._processor_cache
def reset_processor_cache(self) -> bool:
"""Reset the multi-modal processing cache.""" """Reset the multi-modal processing cache."""
if self._processor_cache: if processor_cache := self._get_processor_cache(model_config):
self._processor_cache.reset() processor_cache.reset()
return True # Success return True # Success

View File

@@ -566,7 +566,7 @@ class AsyncLLM(EngineClient):
await self.engine_core.profile_async(False) await self.engine_core.profile_async(False)
async def reset_mm_cache(self) -> None: async def reset_mm_cache(self) -> None:
self.processor.mm_registry.reset_processor_cache() self.processor.mm_registry.reset_processor_cache(self.model_config)
self.processor.mm_input_cache_client.reset() self.processor.mm_input_cache_client.reset()
await self.engine_core.reset_mm_cache_async() await self.engine_core.reset_mm_cache_async()

View File

@@ -271,7 +271,7 @@ class LLMEngine:
self.engine_core.profile(False) self.engine_core.profile(False)
def reset_mm_cache(self): def reset_mm_cache(self):
self.processor.mm_registry.reset_processor_cache() self.processor.mm_registry.reset_processor_cache(self.model_config)
self.processor.mm_input_cache_client.reset() self.processor.mm_input_cache_client.reset()
self.engine_core.reset_mm_cache() self.engine_core.reset_mm_cache()