commit 891df1db6f (parent e2caeb333a)

    update

    Signed-off-by: bk-201 <joy25810@foxmail.com>
@@ -385,15 +385,16 @@ class LoRAModelManager(AdapterModelManager):
                 self.info, "get_num_mm_encoder_tokens")
         else:
             self.supports_mm_lora = False
-        if self.supports_mm_lora:
+        if self.supports_mm_lora:  # This could be passed in from __init__; model_config is not needed here.
             self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
             self.mm_config = model_config.multimodal_config
             limit_per_prompt: int = max(
                 [1] + \
                 list(self.mm_config.limit_per_prompt.values())
             )
             # limit_per_prompt: int = max(
             #     self.info.get_allowed_mm_limits().values())
             limit_per_prompt = 5

             # For vision tower
             # max_num_batched_tokens = encoder_budget
             # max_batches = max_batches * limit_per_prompt
             self.mm_punica_wrapper_mapping = {
                 name:
                 get_punica_wrapper(
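
For context, a minimal sketch of how the limit_per_prompt expression in the hunk above behaves; the per-modality limits here are invented for illustration (as might come from --limit-mm-per-prompt):

# Made-up per-modality limits, standing in for mm_config.limit_per_prompt.
mm_limits = {"image": 4, "video": 2}

# The [1] + ... guard keeps the result at least 1 even if no limits are set.
limit_per_prompt = max([1] + list(mm_limits.values()))
assert limit_per_prompt == 4

no_limits: dict[str, int] = {}
assert max([1] + list(no_limits.values())) == 1
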
@@ -411,8 +412,13 @@ class LoRAModelManager(AdapterModelManager):
                     self.mm_mapping.language_model[0]: self.punica_wrapper
                 }
             )
-            # For other
-            # TODO
+            # TODO Connector is not supported at the moment.
+            self.mm_punica_wrapper_mapping.update(
+                {
+                    name: None
+                    for name in self.mm_mapping.connector
+                }
+            )

         self.is_pooling_model = is_pooling_model(self.model)
         self.packed_modules: dict[str, list[str]] = {}
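
For orientation, a sketch of the shape mm_punica_wrapper_mapping ends up with after the two updates above; the module prefixes and placeholder strings are illustrative only, the real keys come from self.mm_mapping (MultiModelKeys) and the real values are PunicaWrapperBase instances:

# Illustrative only: actual values are objects created by get_punica_wrapper().
mm_punica_wrapper_mapping = {
    "language_model": "<punica_wrapper shared with the LLM backbone>",
    "vision_tower": "<punica_wrapper sized for the encoder budget>",
    # Connector modules map to None, i.e. LoRA is skipped for them (see the TODO above).
    "multi_modal_projector": None,
}
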
@@ -702,15 +708,17 @@ class LoRAModelManager(AdapterModelManager):

     def _get_mm_punica_wrapper(self, module_name: str) -> PunicaWrapperBase:
         """
         TODO
         Match the corresponding punica_wrapper based on module_name,
         and return None if lora is not supported for this module.
         """
         if self.supports_mm_lora:
-            for (
-                    prefix,
-                    punica_wrapper,
-            ) in self.mm_punica_wrapper_mapping.items():
+            # Ensure matching by the longest prefix.
+            sorted_prefixes = sorted(self.mm_punica_wrapper_mapping.keys(),
+                                     key=lambda x: len(x), reverse=True)
+
+            for prefix in sorted_prefixes:
                 if module_name.startswith(prefix):
-                    return punica_wrapper
+                    return self.mm_punica_wrapper_mapping[prefix]
         return None

     def _register_packed_modules(self, module_full_name: str) -> None:
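
The length sort added above matters whenever one registered prefix is itself a prefix of another. A self-contained sketch of the same lookup logic, with invented prefixes and wrapper names:

from typing import Optional

def lookup_by_longest_prefix(mapping: dict[str, Optional[str]],
                             module_name: str) -> Optional[str]:
    # Try longer prefixes first so that e.g. "visual.merger" beats "visual".
    for prefix in sorted(mapping, key=len, reverse=True):
        if module_name.startswith(prefix):
            return mapping[prefix]
    return None

mapping = {
    "visual": "vit_wrapper",
    "visual.merger": "projector_wrapper",  # overlaps with "visual"
}
assert lookup_by_longest_prefix(mapping, "visual.merger.mlp.0") == "projector_wrapper"
assert lookup_by_longest_prefix(mapping, "visual.blocks.0.attn.qkv") == "vit_wrapper"
assert lookup_by_longest_prefix(mapping, "language_model.layers.0") is None
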
@@ -2271,7 +2271,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 mm_counts={
                     dummy_data_modality: 1
                 },
-            ).multi_modal_data
+            )
+            dummy_mm_kwargs = dummy_mm_data.multi_modal_data
+            dummy_mm_token_ids = dummy_mm_data.multi_modal_token_ids

             batched_dummy_mm_inputs = MultiModalKwargs.batch(
                 [dummy_mm_kwargs] * max_num_mm_items,
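
To make the intent of the dummy batching concrete, a toy stand-in for the batching step (this is not the real MultiModalKwargs API): it repeats one dummy multimodal item max_num_mm_items times and stacks its tensors, which is what the profiling path needs for the worst case.

import torch

def batch_dummy_items(dummy_item: dict[str, torch.Tensor],
                      max_num_mm_items: int) -> dict[str, torch.Tensor]:
    # Repeat the same dummy item so profiling sees the worst-case number of
    # multimodal items in a single batch, then stack tensors per key.
    items = [dummy_item] * max_num_mm_items
    return {key: torch.stack([item[key] for item in items]) for key in dummy_item}

dummy = {"pixel_values": torch.zeros(3, 336, 336)}
batched = batch_dummy_items(dummy, max_num_mm_items=4)
assert batched["pixel_values"].shape == (4, 3, 336, 336)
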