Merge pull request #12 from Anexdeus/mlm-full-lora-support

Extended the SupportsMultiModal interface with encoder/connector token-count hooks to enable full LoRA support for multi-modal models.
This commit is contained in:
B-201 2025-12-21 11:02:44 +08:00 committed by GitHub
commit a3a8fc1fd0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 73 additions and 38 deletions

View File

@ -141,6 +141,22 @@ class SupportsMultiModal(Protocol):
"""
...
def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
    """Return the number of multi-modal encoder tokens for the given
    number of image tokens.

    Implement this method to enable LoRA support for the tower module
    (vision encoder) of the multi-modal model: given the number of
    image placeholder tokens, return the number of tokens the
    multi-modal encoder processes for them.
    """
    ...
def get_num_mm_connector_tokens(self, num_vision_tokens: int) -> int:
    """Return the number of multi-modal connector tokens for the given
    number of vision tokens.

    Implement this method to enable LoRA support for the connector
    module of the multi-modal model: given the number of vision-encoder
    output tokens, return the number of tokens the connector emits.
    """
    ...
@overload
def embed_input_ids(self, input_ids: Tensor) -> Tensor: ...

View File

@ -1007,25 +1007,6 @@ class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor):
for modality in ("image", "video")
]
def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
    """Map a count of image placeholder tokens to the number of tokens
    seen by the multi-modal encoder (vision tower).

    Each placeholder stands for a ``spatial_merge_size ** 2`` group of
    encoder tokens, so the count is scaled up by that factor.
    """
    merge = self.get_hf_config().vision_config.spatial_merge_size
    return num_image_tokens * merge * merge
def get_num_mm_connector_tokens(self, num_vision_tokens: int) -> int:
    """Map a count of vision-encoder output tokens to the number of
    tokens emitted by the multi-modal connector (merger).

    The connector folds each ``spatial_merge_size ** 2`` group of
    vision tokens into one token (floor division).
    """
    merge = self.get_hf_config().vision_config.spatial_merge_size
    return num_vision_tokens // (merge * merge)
@MULTIMODAL_REGISTRY.register_processor(
Qwen2_5_VLMultiModalProcessor,
@ -1587,3 +1568,22 @@ class Qwen2_5_VLForConditionalGeneration(
connector="visual.merger.",
tower_model="visual.",
)
def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
    """Map a count of image placeholder tokens to the number of tokens
    seen by the multi-modal encoder (vision tower).

    Each placeholder stands for a ``spatial_merge_size ** 2`` group of
    encoder tokens, so the count is scaled up by that factor.
    """
    merge = self.get_hf_config().vision_config.spatial_merge_size
    return num_image_tokens * merge * merge
def get_num_mm_connector_tokens(self, num_vision_tokens: int) -> int:
    """Map a count of vision-encoder output tokens to the number of
    tokens emitted by the multi-modal connector (merger).

    The connector folds each ``spatial_merge_size ** 2`` group of
    vision tokens into one token (floor division).
    """
    merge = self.get_hf_config().vision_config.spatial_merge_size
    return num_vision_tokens // (merge * merge)

View File

@ -1017,25 +1017,6 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
image_processor=None,
)
def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
    """Map a count of image placeholder tokens to the number of tokens
    seen by the multi-modal encoder (vision tower).

    Each placeholder stands for a ``spatial_merge_size ** 2`` group of
    encoder tokens, so the count is scaled up by that factor.
    """
    merge = self.get_hf_config().vision_config.spatial_merge_size
    return num_image_tokens * merge * merge
def get_num_mm_connector_tokens(self, num_vision_tokens: int) -> int:
    """Map a count of vision-encoder output tokens to the number of
    tokens emitted by the multi-modal connector (merger).

    The connector folds each ``spatial_merge_size ** 2`` group of
    vision tokens into one token (floor division).
    """
    merge = self.get_hf_config().vision_config.spatial_merge_size
    return num_vision_tokens // (merge * merge)
class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
@ -1510,6 +1491,25 @@ class Qwen2VLForConditionalGeneration(
tower_model="visual.",
)
def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
    """Map a count of image placeholder tokens to the number of tokens
    seen by the multi-modal encoder (vision tower).

    Each placeholder stands for a ``spatial_merge_size ** 2`` group of
    encoder tokens, so the count is scaled up by that factor.
    """
    merge = self.get_hf_config().vision_config.spatial_merge_size
    return num_image_tokens * merge * merge
def get_num_mm_connector_tokens(self, num_vision_tokens: int) -> int:
    """Map a count of vision-encoder output tokens to the number of
    tokens emitted by the multi-modal connector (merger).

    The connector folds each ``spatial_merge_size ** 2`` group of
    vision tokens into one token (floor division).
    """
    merge = self.get_hf_config().vision_config.spatial_merge_size
    return num_vision_tokens // (merge * merge)
class Tarsier2MultiModalProcessor(Qwen2VLMultiModalProcessor):
    """Marker subclass of ``Qwen2VLMultiModalProcessor`` for Tarsier2.

    Adds no behavior of its own — Tarsier2 reuses the Qwen2-VL
    multi-modal processing pipeline unchanged.
    """
    pass

View File

@ -2091,3 +2091,22 @@ class Qwen3VLForConditionalGeneration(
connector=["visual.merger", "visual.deepstack_merger_list"],
tower_model="visual.",
)
def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
    """Map a count of image placeholder tokens to the number of tokens
    seen by the multi-modal encoder (vision tower).

    Each placeholder stands for a ``spatial_merge_size ** 2`` group of
    encoder tokens, so the count is scaled up by that factor.
    """
    merge = self.get_hf_config().vision_config.spatial_merge_size
    return num_image_tokens * merge * merge
def get_num_mm_connector_tokens(self, num_vision_tokens: int) -> int:
    """Map a count of vision-encoder output tokens to the number of
    tokens emitted by the multi-modal connector (merger).

    The connector folds each ``spatial_merge_size ** 2`` group of
    vision tokens into one token (floor division).
    """
    merge = self.get_hf_config().vision_config.spatial_merge_size
    return num_vision_tokens // (merge * merge)