Signed-off-by: bk-201 <joy25810@foxmail.com>
bk-201 2025-10-11 02:07:28 +00:00
parent e2caeb333a
commit 891df1db6f
2 changed files with 24 additions and 14 deletions


@@ -385,15 +385,16 @@ class LoRAModelManager(AdapterModelManager):
self.info, "get_num_mm_encoder_tokens")
else:
self.supports_mm_lora = False
if self.supports_mm_lora:
if self.supports_mm_lora: # this could just be passed in from init; model_config wouldn't be needed anymore
self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
self.mm_config = model_config.multimodal_config
limit_per_prompt: int = max(
[1] + \
list(self.mm_config.limit_per_prompt.values())
)
# limit_per_prompt: int = max(
# self.info.get_allowed_mm_limits().values())
limit_per_prompt = 5
# For vision tower
# max_num_batched_tokens = encoder_budget
# max_batches = max_batches * limit_per_prompt
self.mm_punica_wrapper_mapping = {
name:
get_punica_wrapper(
@@ -411,8 +412,13 @@ class LoRAModelManager(AdapterModelManager):
self.mm_mapping.language_model[0]: self.punica_wrapper
}
)
# For other
# TODO
# TODO Connector is not supported at the moment.
self.mm_punica_wrapper_mapping.update(
{
name: None
for name in self.mm_mapping.connector
}
)
self.is_pooling_model = is_pooling_model(self.model)
self.packed_modules: dict[str, list[str]] = {}
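For orientation, a rough sketch of the shape self.mm_punica_wrapper_mapping ends up with after this block; the prefix strings below are placeholders, the real ones come from MultiModelKeys via self.model.get_mm_mapping():

# Hypothetical shape only; prefixes and values are stand-ins, not vLLM objects.
mm_punica_wrapper_mapping = {
    "vision_tower": "dedicated punica wrapper sized from the encoder budget",
    "language_model": "the main self.punica_wrapper",
    "multi_modal_projector": None,  # connector modules: LoRA not supported yet
}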
@@ -702,15 +708,17 @@ class LoRAModelManager(AdapterModelManager):
def _get_mm_punica_wrapper(self, module_name: str) -> PunicaWrapperBase:
"""
TODO
Match the corresponding punica_wrapper based on module_name,
and return None if lora is not supported for this module.
"""
if self.supports_mm_lora:
for (
prefix,
punica_wrapper,
) in self.mm_punica_wrapper_mapping.items():
# Ensure matching by the longest prefix.
sorted_prefixes = sorted(self.mm_punica_wrapper_mapping.keys(),
key=lambda x: len(x), reverse=True)
for prefix in sorted_prefixes:
if module_name.startswith(prefix):
return punica_wrapper
return self.mm_punica_wrapper_mapping[prefix]
return None
def _register_packed_modules(self, module_full_name: str) -> None:
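A minimal, self-contained sketch of the longest-prefix lookup the rewritten _get_mm_punica_wrapper performs; the prefixes and wrapper names are made up for illustration, the real values come from MultiModelKeys and get_punica_wrapper:

from typing import Optional

# Hypothetical prefix-to-wrapper table (strings stand in for PunicaWrapperBase).
mapping: dict[str, Optional[str]] = {
    "model.vision_tower": "vision_wrapper",
    "model.language_model": "lm_wrapper",
    "model.connector": None,  # connector modules: LoRA not supported yet
}

def get_mm_punica_wrapper(module_name: str) -> Optional[str]:
    # Try prefixes longest-first so the most specific match wins; iterating
    # the dict in insertion order could return a shorter, wrong prefix first.
    for prefix in sorted(mapping, key=len, reverse=True):
        if module_name.startswith(prefix):
            return mapping[prefix]
    return None

assert get_mm_punica_wrapper("model.vision_tower.layers.0.attn") == "vision_wrapper"
assert get_mm_punica_wrapper("model.language_model.layers.0.mlp") == "lm_wrapper"
assert get_mm_punica_wrapper("model.connector.proj") is None

Returning self.mm_punica_wrapper_mapping[prefix] inside the loop, as the diff does, is equivalent; the key point is that prefixes are tried longest first.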


@@ -2271,7 +2271,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
mm_counts={
dummy_data_modality: 1
},
).multi_modal_data
)
dummy_mm_kwargs = dummy_mm_data.multi_modal_data
dummy_mm_token_ids = dummy_mm_data.multi_modal_token_ids
batched_dummy_mm_inputs = MultiModalKwargs.batch(
[dummy_mm_kwargs] * max_num_mm_items,
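And a small sketch of the idea behind the second file's change, using a stand-in dataclass rather than vLLM's real dummy-data type: the dummy result is now kept whole so both its kwargs and its token IDs can be read, and the kwargs are replicated to fill the multimodal budget (the names below are assumptions, not the actual API):

from dataclasses import dataclass, field

@dataclass
class DummyMMData:  # stand-in for the dummy-data object; fields assumed from the diff
    multi_modal_data: dict = field(default_factory=dict)
    multi_modal_token_ids: list = field(default_factory=list)

def build_profiling_inputs(dummy_mm_data: DummyMMData, max_num_mm_items: int):
    # Keep the whole dummy result instead of only .multi_modal_data,
    # so the dummy token IDs stay available alongside the kwargs.
    dummy_mm_kwargs = dummy_mm_data.multi_modal_data
    dummy_mm_token_ids = dummy_mm_data.multi_modal_token_ids
    # Replicate the single dummy item max_num_mm_items times, mirroring
    # MultiModalKwargs.batch([dummy_mm_kwargs] * max_num_mm_items) in the diff.
    batched_dummy_mm_inputs = [dummy_mm_kwargs] * max_num_mm_items
    return batched_dummy_mm_inputs, dummy_mm_token_ids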