added ProcessingInfoMixin for QwenVL series models

This commit is contained in:
Anexdeus 2025-12-20 12:29:46 +03:00
parent 36121c6db0
commit b03d1a04a8
2 changed files with 28 additions and 91 deletions

View File

@ -820,7 +820,34 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser):
return super()._parse_video_data(data)
class Qwen2VLProcessingInfo(BaseProcessingInfo):
class QwenVLSeriesProcessingInfoMixin:
    """
    Shared token-count helpers for Qwen-VL-family processing-info classes.

    Provides ``get_num_mm_encoder_tokens()`` and
    ``get_num_mm_connector_tokens()`` so that only Qwen-VL series models
    gain these hooks, without affecting other multi-modal models.
    Host classes must supply ``get_hf_config()`` whose result exposes
    ``vision_config.spatial_merge_size``.
    """

    def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
        """Return ``num_image_tokens * spatial_merge_size ** 2``."""
        merge = self.get_hf_config().vision_config.spatial_merge_size
        return num_image_tokens * merge * merge

    def get_num_mm_connector_tokens(self, num_vision_tokens: int) -> int:
        """Return ``num_vision_tokens // spatial_merge_size ** 2`` (floor)."""
        merge = self.get_hf_config().vision_config.spatial_merge_size
        return num_vision_tokens // (merge * merge)
class Qwen2VLProcessingInfo(QwenVLSeriesProcessingInfoMixin, BaseProcessingInfo):
def get_hf_config(self):
    """Fetch this model's Hugging Face config, typed as ``Qwen2VLConfig``."""
    hf_config = self.ctx.get_hf_config(Qwen2VLConfig)
    return hf_config
@ -1017,25 +1044,6 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
image_processor=None,
)
def get_num_mm_encoder_tokens(
    self,
    num_image_tokens: int,
) -> int:
    """Return the encoder token count: ``num_image_tokens * merge_size ** 2``.

    ``merge_size`` is ``vision_config.spatial_merge_size`` from the HF config.
    """
    merge = self.get_hf_config().vision_config.spatial_merge_size
    return num_image_tokens * merge * merge
def get_num_mm_connector_tokens(
    self,
    num_vision_tokens: int,
) -> int:
    """Return the connector token count: ``num_vision_tokens // merge_size ** 2``.

    ``merge_size`` is ``vision_config.spatial_merge_size`` from the HF config;
    the division floors, matching integer token counts.
    """
    merge = self.get_hf_config().vision_config.spatial_merge_size
    return num_vision_tokens // (merge * merge)
class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
@ -1132,25 +1140,6 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo])
self.info.get_hf_config().vision_config.spatial_merge_size
)(hf_inputs)
def get_num_mm_encoder_tokens(
    self,
    num_image_tokens: int,
) -> int:
    """Return the encoder token count: ``num_image_tokens * merge_size ** 2``.

    Reads ``spatial_merge_size`` from ``self.info.get_hf_config().vision_config``.
    """
    merge = self.info.get_hf_config().vision_config.spatial_merge_size
    return num_image_tokens * merge * merge
def get_num_mm_connector_tokens(
    self,
    num_vision_tokens: int,
) -> int:
    """Return the connector token count: ``num_vision_tokens // merge_size ** 2``.

    Reads ``spatial_merge_size`` from ``self.info.get_hf_config().vision_config``;
    the division floors.
    """
    merge = self.info.get_hf_config().vision_config.spatial_merge_size
    return num_vision_tokens // (merge * merge)
@MULTIMODAL_REGISTRY.register_processor(
Qwen2VLMultiModalProcessor,

View File

@ -1185,32 +1185,6 @@ class BaseProcessingInfo:
"""
return self.ctx.get_hf_processor(**kwargs)
@abstractmethod
def get_num_mm_encoder_tokens(
    self,
    num_image_tokens: int,
) -> int:
    """
    Implement this function to enable LoRA support
    for the tower module of the multi-modal model.

    Given the number of image tokens, return the number of
    multi-modal encoder tokens.
    """
    raise NotImplementedError
@abstractmethod
def get_num_mm_connector_tokens(
    self,
    num_vision_tokens: int,
) -> int:
    """
    Implement this function to enable LoRA support
    for the connector module of the multi-modal model.

    Given the number of vision tokens, return the number of
    multi-modal connector tokens.
    """
    raise NotImplementedError
@abstractmethod
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
"""
@ -1415,32 +1389,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
"""Given the HF-processed data, output the metadata of each field."""
raise NotImplementedError
@abstractmethod
def get_num_mm_encoder_tokens(
    self,
    num_image_tokens: int,
) -> int:
    """
    Implement this function to enable LoRA support
    for the tower module of the multi-modal model.

    Given the number of image tokens, return the number of
    multi-modal encoder tokens.
    """
    raise NotImplementedError
@abstractmethod
def get_num_mm_connector_tokens(
    self,
    num_vision_tokens: int,
) -> int:
    """
    Implement this function to enable LoRA support
    for the connector module of the multi-modal model.

    Given the number of vision tokens, return the number of
    multi-modal connector tokens.
    """
    raise NotImplementedError
@abstractmethod
def _get_prompt_updates(
self,