move mm-token-functions to model

Signed-off-by: bk-201 <joy25810@foxmail.com>
This commit is contained in:
bk-201 2025-12-21 03:34:32 +00:00
parent a3a8fc1fd0
commit 20402090b8
6 changed files with 23 additions and 24 deletions

View File

@ -158,7 +158,7 @@ class LoRAModelManager:
model_config model_config
).info ).info
self.supports_tower_connector_lora = self.supports_mm and hasattr( self.supports_tower_connector_lora = self.supports_mm and hasattr(
self.mm_processor_info, "get_num_mm_encoder_tokens" self.model, "get_num_mm_encoder_tokens"
) )
if not self.supports_tower_connector_lora: if not self.supports_tower_connector_lora:
return return
@ -177,7 +177,7 @@ class LoRAModelManager:
limit_per_prompt: int = max( limit_per_prompt: int = max(
self.mm_processor_info.get_allowed_mm_limits().values() self.mm_processor_info.get_allowed_mm_limits().values()
) )
num_encoder_tokens = self.mm_processor_info.get_num_mm_encoder_tokens( num_encoder_tokens = self.model.get_num_mm_encoder_tokens(
mm_budget.get_encoder_budget() mm_budget.get_encoder_budget()
) )
@ -193,8 +193,8 @@ class LoRAModelManager:
# Use wrapper for connector if present. # Use wrapper for connector if present.
if self.mm_mapping.connector: if self.mm_mapping.connector:
if hasattr(self.mm_processor_info, "get_num_mm_connector_tokens"): if hasattr(self.model, "get_num_mm_connector_tokens"):
connector_tokens = self.mm_processor_info.get_num_mm_connector_tokens( connector_tokens = self.model.get_num_mm_connector_tokens(
num_encoder_tokens num_encoder_tokens
) )
connector_punica_wrapper = get_punica_wrapper( connector_punica_wrapper = get_punica_wrapper(

View File

@ -144,16 +144,18 @@ class SupportsMultiModal(Protocol):
def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int: def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
""" """
Implement this function to enable LoRA support Implement this function to enable LoRA support
for the tower module of the multi-modal model for the tower module of the multi-modal model.
Given the number of image tokens, output the number of multi-modal encoder tokens Given the number of image tokens, output the number of
multi-modal encoder tokens.
""" """
... ...
def get_num_mm_connector_tokens(self, num_vision_tokens: int) -> int: def get_num_mm_connector_tokens(self, num_vision_tokens: int) -> int:
""" """
Implement this function to enable LoRA support Implement this function to enable LoRA support
for the connector module of the multi-modal model for the connector module of the multi-modal model.
Given the number of vision tokens, output the number of multi-modal connector tokens Given the number of vision tokens, output the number of
multi-modal connector tokens.
""" """
... ...

View File

@ -1573,7 +1573,7 @@ class Qwen2_5_VLForConditionalGeneration(
self, self,
num_image_tokens: int, num_image_tokens: int,
) -> int: ) -> int:
hf_config = self.get_hf_config() hf_config = self.config
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
merge_size = vision_config.spatial_merge_size merge_size = vision_config.spatial_merge_size
@ -1583,7 +1583,7 @@ class Qwen2_5_VLForConditionalGeneration(
self, self,
num_vision_tokens: int, num_vision_tokens: int,
) -> int: ) -> int:
hf_config = self.get_hf_config() hf_config = self.config
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
merge_size = vision_config.spatial_merge_size merge_size = vision_config.spatial_merge_size
return num_vision_tokens // merge_size**2 return num_vision_tokens // merge_size**2

View File

@ -1495,7 +1495,7 @@ class Qwen2VLForConditionalGeneration(
self, self,
num_image_tokens: int, num_image_tokens: int,
) -> int: ) -> int:
hf_config = self.get_hf_config() hf_config = self.config
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
merge_size = vision_config.spatial_merge_size merge_size = vision_config.spatial_merge_size
@ -1505,7 +1505,7 @@ class Qwen2VLForConditionalGeneration(
self, self,
num_vision_tokens: int, num_vision_tokens: int,
) -> int: ) -> int:
hf_config = self.get_hf_config() hf_config = self.config
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
merge_size = vision_config.spatial_merge_size merge_size = vision_config.spatial_merge_size
return num_vision_tokens // merge_size**2 return num_vision_tokens // merge_size**2

View File

@ -2096,7 +2096,7 @@ class Qwen3VLForConditionalGeneration(
self, self,
num_image_tokens: int, num_image_tokens: int,
) -> int: ) -> int:
hf_config = self.get_hf_config() hf_config = self.config
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
merge_size = vision_config.spatial_merge_size merge_size = vision_config.spatial_merge_size
@ -2106,7 +2106,7 @@ class Qwen3VLForConditionalGeneration(
self, self,
num_vision_tokens: int, num_vision_tokens: int,
) -> int: ) -> int:
hf_config = self.get_hf_config() hf_config = self.config
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
merge_size = vision_config.spatial_merge_size merge_size = vision_config.spatial_merge_size
return num_vision_tokens // merge_size**2 return num_vision_tokens // merge_size**2

View File

@ -2160,15 +2160,12 @@ class GPUModelRunner(
# encoder outputs. # encoder outputs.
model = cast(SupportsMultiModal, self.model) model = cast(SupportsMultiModal, self.model)
if self.lora_manager.supports_tower_connector_lora(): if self.lora_config and self.lora_manager.supports_tower_connector_lora():
# Build LoRA mappings independently for encoder inputs # Build LoRA mappings independently for encoder inputs
# (encoder batch structure is different from main batch) # (encoder batch structure is different from main batch)
prompt_lora_mapping = [] prompt_lora_mapping = []
token_lora_mapping = [] token_lora_mapping = []
lora_requests = set() lora_requests = set()
# This implementation is a bit hacky, but it's mainly to retrieve
# the get_num_mm_*_tokens helper functions from ProcessingInfo.
mm_processor_info = self.lora_manager._adapter_manager.mm_processor_info
for req_id, (_, pos_info) in zip(encoder_req_ids, mm_hashes_pos): for req_id, (_, pos_info) in zip(encoder_req_ids, mm_hashes_pos):
req_idx = self.input_batch.req_id_to_index[req_id] req_idx = self.input_batch.req_id_to_index[req_id]
@ -2177,7 +2174,7 @@ class GPUModelRunner(
# Prefer pos_info.is_embed to count actual MM embedding tokens. # Prefer pos_info.is_embed to count actual MM embedding tokens.
# pos_info.length may overcount (e.g., special tokens in Qwen-VL). # pos_info.length may overcount (e.g., special tokens in Qwen-VL).
# Fall back to length if is_embed is None. # Fall back to length if is_embed is None.
num_tokens = mm_processor_info.get_num_mm_encoder_tokens( # type: ignore[attr-defined] num_tokens = self.model.get_num_mm_encoder_tokens( # type: ignore[attr-defined]
pos_info.get_num_embeds pos_info.get_num_embeds
) )
prompt_lora_mapping.append(lora_id) prompt_lora_mapping.append(lora_id)
@ -2196,13 +2193,13 @@ class GPUModelRunner(
) )
self.lora_manager.set_active_adapters(lora_requests, lora_mapping) self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
if hasattr(mm_processor_info, "get_num_mm_connector_tokens"): if hasattr(self.model, "get_num_mm_connector_tokens"):
num_post_op_tokens = [] num_post_op_tokens = []
for _, pos_info in mm_hashes_pos: for _, pos_info in mm_hashes_pos:
mm_token_count = mm_processor_info.get_num_mm_encoder_tokens( # type: ignore[attr-defined] mm_token_count = self.model.get_num_mm_encoder_tokens( # type: ignore[attr-defined]
pos_info.length pos_info.length
) )
post_op_count = mm_processor_info.get_num_mm_connector_tokens( # type: ignore[attr-defined] post_op_count = self.model.get_num_mm_connector_tokens( # type: ignore[attr-defined]
mm_token_count mm_token_count
) )
num_post_op_tokens.append(post_op_count) num_post_op_tokens.append(post_op_count)