mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-26 10:57:05 +08:00
update mm filter
Signed-off-by: bk-201 <joy25810@foxmail.com>
This commit is contained in:
parent
98debc2424
commit
7db0d5990a
@ -386,19 +386,33 @@ class LoRAModelManager(AdapterModelManager):
|
|||||||
self.supports_mm_lora = False
|
self.supports_mm_lora = False
|
||||||
if self.supports_mm_lora:
|
if self.supports_mm_lora:
|
||||||
self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
|
self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
|
||||||
|
self.mm_config = model_config.multimodal_config
|
||||||
|
limit_per_prompt: int = max(
|
||||||
|
[1] + \
|
||||||
|
list(self.mm_config.limit_per_prompt.values())
|
||||||
|
)
|
||||||
|
|
||||||
|
# For vision tower
|
||||||
self.mm_punica_wrapper_mapping = {
|
self.mm_punica_wrapper_mapping = {
|
||||||
name:
|
name:
|
||||||
get_punica_wrapper(
|
get_punica_wrapper(
|
||||||
self.info.get_num_mm_encoder_tokens(
|
self.info.get_num_mm_encoder_tokens(
|
||||||
max_num_batched_tokens),
|
max_num_batched_tokens),
|
||||||
max_batches=self.max_num_seqs, # TODO
|
max_batches=self.max_num_seqs * limit_per_prompt,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
max_loras=self.lora_config.max_loras,
|
max_loras=self.lora_config.max_loras,
|
||||||
)
|
)
|
||||||
for name in self.mm_mapping.tower_model
|
for name in self.mm_mapping.tower_model
|
||||||
}
|
}
|
||||||
self.mm_punica_wrapper_mapping[
|
# For language model
|
||||||
self.mm_mapping.language_model[0]] = self.punica_wrapper
|
self.mm_punica_wrapper_mapping.update(
|
||||||
|
{
|
||||||
|
self.mm_mapping.language_model[0]: self.punica_wrapper
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# For other
|
||||||
|
# TODO
|
||||||
|
|
||||||
self.is_pooling_model = is_pooling_model(self.model)
|
self.is_pooling_model = is_pooling_model(self.model)
|
||||||
self.packed_modules: dict[str, list[str]] = {}
|
self.packed_modules: dict[str, list[str]] = {}
|
||||||
self.modules: dict[str, BaseLayerWithLoRA] = {}
|
self.modules: dict[str, BaseLayerWithLoRA] = {}
|
||||||
@ -539,9 +553,7 @@ class LoRAModelManager(AdapterModelManager):
|
|||||||
continue
|
continue
|
||||||
# A temporary approach for multimodal models to support LoRA
|
# A temporary approach for multimodal models to support LoRA
|
||||||
# TODO: Remove this restriction
|
# TODO: Remove this restriction
|
||||||
if (self._filter_unsupported_mm_module(module_name)
|
if self._filter_unsupported_mm_module(module_name):
|
||||||
and not self.supports_mm_lora
|
|
||||||
or self._get_mm_punica_wrapper(module_name) is None):
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Regarding multimodal models, vLLM currently only supports "
|
"Regarding multimodal models, vLLM currently only supports "
|
||||||
"adding LoRA to language model, %s will be ignored.",
|
"adding LoRA to language model, %s will be ignored.",
|
||||||
@ -678,10 +690,13 @@ class LoRAModelManager(AdapterModelManager):
|
|||||||
be filtered out.
|
be filtered out.
|
||||||
"""
|
"""
|
||||||
if self.supports_mm:
|
if self.supports_mm:
|
||||||
module_mapping: MultiModelKeys = self.model.get_mm_mapping()
|
prefix_lst = self.mm_mapping.connector + self.mm_mapping.tower_model
|
||||||
prefix_lst = module_mapping.connector + module_mapping.tower_model
|
if self.supports_mm_lora:
|
||||||
return any(
|
|
||||||
[module_name.startswith(prefix) for prefix in prefix_lst])
|
return self._get_mm_punica_wrapper(module_name) is None
|
||||||
|
else:
|
||||||
|
return any(
|
||||||
|
[module_name.startswith(prefix) for prefix in prefix_lst])
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _get_mm_punica_wrapper(self, module_name: str) -> PunicaWrapperBase:
|
def _get_mm_punica_wrapper(self, module_name: str) -> PunicaWrapperBase:
|
||||||
|
|||||||
@ -157,10 +157,11 @@ class LoRAModelRunnerMixin:
|
|||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig,
|
def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig,
|
||||||
num_scheduled_tokens: np.ndarray):
|
num_scheduled_tokens: np.ndarray,
|
||||||
|
is_mm_input: bool = False):
|
||||||
with self.maybe_setup_dummy_loras(
|
with self.maybe_setup_dummy_loras(
|
||||||
lora_config), self.maybe_select_dummy_loras(
|
lora_config), self.maybe_select_dummy_loras(
|
||||||
lora_config, num_scheduled_tokens):
|
lora_config, num_scheduled_tokens, is_mm_input):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user