Revert the mixin changes

This commit is contained in:
Anexdeus 2025-12-20 13:31:53 +03:00
parent b03d1a04a8
commit d525556a25
2 changed files with 42 additions and 28 deletions

View File

@@ -820,34 +820,7 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser):
return super()._parse_video_data(data)
class QwenVLSeriesProcessingInfoMixin:
    """
    Mixin that provides get_num_mm_encoder_tokens()
    and get_num_mm_connector_tokens() methods for
    QwenVL series models without affecting other multi-modal models.

    The host class must provide ``get_hf_config()``; the returned config is
    read through ``vision_config.spatial_merge_size``.
    """

    def get_num_mm_encoder_tokens(
        self,
        num_image_tokens: int,
    ) -> int:
        """
        Convert an image-token count into an encoder-token count.

        Args:
            num_image_tokens: Number of image tokens.

        Returns:
            ``num_image_tokens`` multiplied by ``spatial_merge_size ** 2``.
        """
        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        merge_size = vision_config.spatial_merge_size

        return num_image_tokens * merge_size**2

    def get_num_mm_connector_tokens(
        self,
        num_vision_tokens: int,
    ) -> int:
        """
        Convert a vision-token count into a connector-token count.

        Args:
            num_vision_tokens: Number of vision tokens.

        Returns:
            ``num_vision_tokens`` floor-divided by ``spatial_merge_size ** 2``;
            any remainder is discarded.
        """
        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        merge_size = vision_config.spatial_merge_size

        return num_vision_tokens // merge_size**2
class Qwen2VLProcessingInfo(QwenVLSeriesProcessingInfoMixin, BaseProcessingInfo):
class Qwen2VLProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
    """Return the HF config obtained from ``self.ctx`` as a ``Qwen2VLConfig``."""
    return self.ctx.get_hf_config(Qwen2VLConfig)
@@ -1131,6 +1104,25 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo])
for modality in ("image", "video")
]
def get_num_mm_encoder_tokens(
    self,
    num_image_tokens: int,
) -> int:
    """
    Convert an image-token count into an encoder-token count.

    Args:
        num_image_tokens: Number of image tokens.

    Returns:
        ``num_image_tokens`` multiplied by ``spatial_merge_size ** 2``.
    """
    # ``get_hf_config`` is defined on the processing-info class; use it
    # directly when this object has it, otherwise go through ``self.info``.
    # NOTE(review): assumes the processor exposes its processing info as
    # ``self.info`` -- confirm against BaseMultiModalProcessor.
    get_hf_config = getattr(self, "get_hf_config", None)
    if get_hf_config is None:
        get_hf_config = self.info.get_hf_config

    hf_config = get_hf_config()
    vision_config = hf_config.vision_config
    merge_size = vision_config.spatial_merge_size

    return num_image_tokens * merge_size**2
def get_num_mm_connector_tokens(
    self,
    num_vision_tokens: int,
) -> int:
    """
    Convert a vision-token count into a connector-token count.

    Args:
        num_vision_tokens: Number of vision tokens.

    Returns:
        ``num_vision_tokens`` floor-divided by ``spatial_merge_size ** 2``;
        any remainder is discarded.
    """
    # ``get_hf_config`` is defined on the processing-info class; use it
    # directly when this object has it, otherwise go through ``self.info``.
    # NOTE(review): assumes the processor exposes its processing info as
    # ``self.info`` -- confirm against BaseMultiModalProcessor.
    get_hf_config = getattr(self, "get_hf_config", None)
    if get_hf_config is None:
        get_hf_config = self.info.get_hf_config

    hf_config = get_hf_config()
    vision_config = hf_config.vision_config
    merge_size = vision_config.spatial_merge_size

    return num_vision_tokens // merge_size**2
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,

View File

@@ -1412,6 +1412,28 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
"""
raise NotImplementedError
def get_num_mm_encoder_tokens(
    self,
    num_image_tokens: int,
) -> int:
    """
    Implement this function to enable LoRA support
    for the tower module of the multi-modal model.

    Given the number of image tokens, output the number of
    multi-modal encoder tokens.

    Args:
        num_image_tokens: Number of image tokens.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
def get_num_mm_connector_tokens(
    self,
    num_vision_tokens: int,
) -> int:
    """
    Implement this function to enable LoRA support
    for the connector module of the multi-modal model.

    Given the number of vision tokens, output the number of
    multi-modal connector tokens.

    Args:
        num_vision_tokens: Number of vision tokens.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
def _bind_and_group_updates(
self,
prompt_updates: Sequence[PromptUpdate],