mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 06:15:00 +08:00
[CI/Build] Fix multimodal tests (#22491)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
808a7b69df
commit
1712543df6
@ -845,7 +845,8 @@ class LLMEngine:
|
|||||||
|
|
||||||
def reset_mm_cache(self) -> bool:
|
def reset_mm_cache(self) -> bool:
|
||||||
"""Reset the multi-modal cache."""
|
"""Reset the multi-modal cache."""
|
||||||
return self.input_preprocessor.mm_registry.reset_processor_cache()
|
return self.input_preprocessor.mm_registry.reset_processor_cache(
|
||||||
|
self.model_config)
|
||||||
|
|
||||||
def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
|
def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
|
||||||
"""Reset prefix cache for all devices."""
|
"""Reset prefix cache for all devices."""
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from functools import lru_cache
|
||||||
from typing import TYPE_CHECKING, Generic, Optional, Protocol, TypeVar
|
from typing import TYPE_CHECKING, Generic, Optional, Protocol, TypeVar
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -86,6 +87,13 @@ class _ProcessorFactories(Generic[_I]):
|
|||||||
return self.processor(info, dummy_inputs_builder, cache=cache)
|
return self.processor(info, dummy_inputs_builder, cache=cache)
|
||||||
|
|
||||||
|
|
||||||
|
# Make sure a different cache is used for each model config
|
||||||
|
# NOTE: ModelConfig is not hashable so it cannot be passed directly
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def _get_processor_cache(model_id: str, capacity_gb: int):
|
||||||
|
return ProcessingCache(capacity_gb) if capacity_gb > 0 else None
|
||||||
|
|
||||||
|
|
||||||
class MultiModalRegistry:
|
class MultiModalRegistry:
|
||||||
"""
|
"""
|
||||||
A registry that dispatches data processing according to the model.
|
A registry that dispatches data processing according to the model.
|
||||||
@ -95,22 +103,15 @@ class MultiModalRegistry:
|
|||||||
self._processor_factories = ClassRegistry[nn.Module,
|
self._processor_factories = ClassRegistry[nn.Module,
|
||||||
_ProcessorFactories]()
|
_ProcessorFactories]()
|
||||||
|
|
||||||
self._processor_cache: Optional[ProcessingCache] = None
|
|
||||||
|
|
||||||
def _get_processor_cache(self, model_config: "ModelConfig"):
|
def _get_processor_cache(self, model_config: "ModelConfig"):
|
||||||
|
model_id = model_config.model
|
||||||
capacity_gb = model_config.mm_processor_cache_gb
|
capacity_gb = model_config.mm_processor_cache_gb
|
||||||
if capacity_gb is None:
|
return _get_processor_cache(model_id, capacity_gb)
|
||||||
return None # Overrides `disable_cache` argument
|
|
||||||
|
|
||||||
if self._processor_cache is None:
|
def reset_processor_cache(self, model_config: "ModelConfig") -> bool:
|
||||||
self._processor_cache = ProcessingCache(capacity_gb)
|
|
||||||
|
|
||||||
return self._processor_cache
|
|
||||||
|
|
||||||
def reset_processor_cache(self) -> bool:
|
|
||||||
"""Reset the multi-modal processing cache."""
|
"""Reset the multi-modal processing cache."""
|
||||||
if self._processor_cache:
|
if processor_cache := self._get_processor_cache(model_config):
|
||||||
self._processor_cache.reset()
|
processor_cache.reset()
|
||||||
|
|
||||||
return True # Success
|
return True # Success
|
||||||
|
|
||||||
|
|||||||
@ -566,7 +566,7 @@ class AsyncLLM(EngineClient):
|
|||||||
await self.engine_core.profile_async(False)
|
await self.engine_core.profile_async(False)
|
||||||
|
|
||||||
async def reset_mm_cache(self) -> None:
|
async def reset_mm_cache(self) -> None:
|
||||||
self.processor.mm_registry.reset_processor_cache()
|
self.processor.mm_registry.reset_processor_cache(self.model_config)
|
||||||
self.processor.mm_input_cache_client.reset()
|
self.processor.mm_input_cache_client.reset()
|
||||||
await self.engine_core.reset_mm_cache_async()
|
await self.engine_core.reset_mm_cache_async()
|
||||||
|
|
||||||
|
|||||||
@ -271,7 +271,7 @@ class LLMEngine:
|
|||||||
self.engine_core.profile(False)
|
self.engine_core.profile(False)
|
||||||
|
|
||||||
def reset_mm_cache(self):
|
def reset_mm_cache(self):
|
||||||
self.processor.mm_registry.reset_processor_cache()
|
self.processor.mm_registry.reset_processor_cache(self.model_config)
|
||||||
self.processor.mm_input_cache_client.reset()
|
self.processor.mm_input_cache_client.reset()
|
||||||
self.engine_core.reset_mm_cache()
|
self.engine_core.reset_mm_cache()
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user