mirror of https://git.datalinker.icu/vllm-project/vllm.git
extended SupportsMultiModal

This commit is contained in:
parent cd32aeadfa
commit c6831e793d

@@ -154,9 +154,8 @@ class LoRAModelManager:
         self.punica_wrapper_mapping[lm_prefix] = llm_punica_wrapper

         if self.lora_config.enable_tower_connector_lora:
-            self.info = MULTIMODAL_REGISTRY.create_processor(model_config).info
             self.supports_tower_connector_lora = self.supports_mm and hasattr(
-                self.info, "get_num_mm_encoder_tokens"
+                self.model, "get_num_mm_encoder_tokens"
             )
             if not self.supports_tower_connector_lora:
                 return
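
Note: the hunk above drops the processor-based probe (`MULTIMODAL_REGISTRY.create_processor(...).info`) and checks the loaded model object directly. A minimal sketch of the capability probe, using a hypothetical stand-in model class rather than a real vLLM one:

class _ToyMultiModalModel:
    def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
        # Hypothetical expansion factor; real models derive this from config.
        return num_image_tokens * 4

model = _ToyMultiModalModel()
supports_tower_connector_lora = hasattr(model, "get_num_mm_encoder_tokens")
assert supports_tower_connector_lora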

@@ -172,8 +171,8 @@ class LoRAModelManager:
             vllm_config.scheduler_config,
             MULTIMODAL_REGISTRY,
         )
-        limit_per_prompt: int = max(self.info.get_allowed_mm_limits().values())
-        num_encoder_tokens = self.info.get_num_mm_encoder_tokens(
+        limit_per_prompt: int = max(self.model.get_allowed_mm_limits().values())
+        num_encoder_tokens = self.model.get_num_mm_encoder_tokens(
             mm_budget.get_encoder_budget()
         )

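
Note: `limit_per_prompt` above is the largest per-modality cap the model reports. A hypothetical illustration of that sizing step:

allowed_mm_limits = {"image": 10, "video": 2}  # what get_allowed_mm_limits() might return
limit_per_prompt = max(allowed_mm_limits.values())
assert limit_per_prompt == 10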

@@ -189,8 +188,8 @@ class LoRAModelManager:

         # Use wrapper for connector if present.
         if self.mm_mapping.connector:
-            if hasattr(self.info, "get_num_mm_connector_tokens"):
-                connector_tokens = self.info.get_num_mm_connector_tokens(
+            if hasattr(self.model, "get_num_mm_connector_tokens"):
+                connector_tokens = self.model.get_num_mm_connector_tokens(
                     num_encoder_tokens
                 )
                 connector_punica_wrapper = get_punica_wrapper(

@@ -141,6 +141,22 @@ class SupportsMultiModal(Protocol):
         """
         ...

+    def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
+        """
+        Implement this function to enable LoRA support
+        for the tower module of the multi-modal model
+        Given the number of image tokens, output the number of multi-modal encoder tokens
+        """
+        ...
+
+    def get_num_mm_connector_tokens(self, num_vision_tokens: int) -> int:
+        """
+        Implement this function to enable LoRA support
+        for the connector module of the multi-modal model
+        Given the number of vision tokens, output the number of multi-modal connector tokens
+        """
+        ...
+
     @overload
     def embed_input_ids(self, input_ids: Tensor) -> Tensor: ...

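
Note: for merger-style connectors the two hooks are inverses of each other. A worked example with hypothetical Qwen2-VL-style numbers (spatial_merge_size == 2), matching the model implementations added below:

merge_size = 2
num_image_tokens = 64  # placeholder tokens in the final prompt
num_encoder_tokens = num_image_tokens * merge_size**2  # 256 tokens enter the vision tower
num_connector_tokens = num_encoder_tokens // merge_size**2  # the merger reduces back to 64
assert num_connector_tokens == num_image_tokens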

@@ -1568,3 +1568,39 @@ class Qwen2_5_VLForConditionalGeneration(
             connector="visual.merger.",
             tower_model="visual.",
         )
+
+    def get_num_mm_encoder_tokens(
+        self,
+        num_image_tokens: int,
+    ) -> int:
+        hf_config = self.get_hf_config()
+        vision_config = hf_config.vision_config
+        merge_size = vision_config.spatial_merge_size
+
+        return num_image_tokens * merge_size**2
+
+    def get_num_mm_connector_tokens(
+        self,
+        num_vision_tokens: int,
+    ) -> int:
+        hf_config = self.get_hf_config()
+        vision_config = hf_config.vision_config
+        merge_size = vision_config.spatial_merge_size
+        return num_vision_tokens // merge_size**2
+
+    def get_allowed_mm_limits(self) -> Mapping[str, int]:
+        """Return the maximum allowed number of items for each modality."""
+        supported_mm_limits = self.get_supported_mm_limits()
+        mm_config = self.ctx.get_mm_config()
+
+        allowed_limits = dict[str, int]()
+        for modality, supported_limit in supported_mm_limits.items():
+            user_limit = mm_config.get_limit_per_prompt(modality)
+
+            allowed_limits[modality] = (
+                user_limit
+                if supported_limit is None
+                else min(user_limit, supported_limit)
+            )
+
+        return allowed_limits
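
Note: `get_allowed_mm_limits` clamps the user-configured limit to what the model supports, with `None` meaning the model imposes no cap. A sketch of that rule with hypothetical values:

supported_limit = 10  # from get_supported_mm_limits(); None means unbounded
user_limit = 16       # e.g. from --limit-mm-per-prompt
allowed = user_limit if supported_limit is None else min(user_limit, supported_limit)
assert allowed == 10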

@@ -1104,25 +1104,6 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo])
             for modality in ("image", "video")
         ]

-    def get_num_mm_encoder_tokens(
-        self,
-        num_image_tokens: int,
-    ) -> int:
-        hf_config = self.get_hf_config()
-        vision_config = hf_config.vision_config
-        merge_size = vision_config.spatial_merge_size
-
-        return num_image_tokens * merge_size**2
-
-    def get_num_mm_connector_tokens(
-        self,
-        num_vision_tokens: int,
-    ) -> int:
-        hf_config = self.get_hf_config()
-        vision_config = hf_config.vision_config
-        merge_size = vision_config.spatial_merge_size
-        return num_vision_tokens // merge_size**2
-
     def _get_mm_fields_config(
         self,
         hf_inputs: BatchFeature,

@@ -1510,6 +1491,42 @@ class Qwen2VLForConditionalGeneration(
             tower_model="visual.",
         )

+    def get_allowed_mm_limits(self) -> Mapping[str, int]:
+        """Return the maximum allowed number of items for each modality."""
+        supported_mm_limits = self.get_supported_mm_limits()
+        mm_config = self.ctx.get_mm_config()
+
+        allowed_limits = dict[str, int]()
+        for modality, supported_limit in supported_mm_limits.items():
+            user_limit = mm_config.get_limit_per_prompt(modality)
+
+            allowed_limits[modality] = (
+                user_limit
+                if supported_limit is None
+                else min(user_limit, supported_limit)
+            )
+
+        return allowed_limits
+
+    def get_num_mm_encoder_tokens(
+        self,
+        num_image_tokens: int,
+    ) -> int:
+        hf_config = self.get_hf_config()
+        vision_config = hf_config.vision_config
+        merge_size = vision_config.spatial_merge_size
+
+        return num_image_tokens * merge_size**2
+
+    def get_num_mm_connector_tokens(
+        self,
+        num_vision_tokens: int,
+    ) -> int:
+        hf_config = self.get_hf_config()
+        vision_config = hf_config.vision_config
+        merge_size = vision_config.spatial_merge_size
+        return num_vision_tokens // merge_size**2
+

 class Tarsier2MultiModalProcessor(Qwen2VLMultiModalProcessor):
     pass

@@ -2091,3 +2091,39 @@ class Qwen3VLForConditionalGeneration(
             connector=["visual.merger", "visual.deepstack_merger_list"],
             tower_model="visual.",
         )
+
+    def get_num_mm_encoder_tokens(
+        self,
+        num_image_tokens: int,
+    ) -> int:
+        hf_config = self.get_hf_config()
+        vision_config = hf_config.vision_config
+        merge_size = vision_config.spatial_merge_size
+
+        return num_image_tokens * merge_size**2
+
+    def get_num_mm_connector_tokens(
+        self,
+        num_vision_tokens: int,
+    ) -> int:
+        hf_config = self.get_hf_config()
+        vision_config = hf_config.vision_config
+        merge_size = vision_config.spatial_merge_size
+        return num_vision_tokens // merge_size**2
+
+    def get_allowed_mm_limits(self) -> Mapping[str, int]:
+        """Return the maximum allowed number of items for each modality."""
+        supported_mm_limits = self.get_supported_mm_limits()
+        mm_config = self.ctx.get_mm_config()
+
+        allowed_limits = dict[str, int]()
+        for modality, supported_limit in supported_mm_limits.items():
+            user_limit = mm_config.get_limit_per_prompt(modality)
+
+            allowed_limits[modality] = (
+                user_limit
+                if supported_limit is None
+                else min(user_limit, supported_limit)
+            )
+
+        return allowed_limits

@@ -1420,28 +1420,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         """
         raise NotImplementedError

-    def get_num_mm_encoder_tokens(
-        self,
-        num_image_tokens: int,
-    ) -> int:
-        """
-        Implement this function to enable LoRA support
-        for the tower module of the multi-modal model
-        Given the number of image tokens, output the number of multi-modal encoder tokens
-        """
-        raise NotImplementedError
-
-    def get_num_mm_connector_tokens(
-        self,
-        num_vision_tokens: int,
-    ) -> int:
-        """
-        Implement this function to enable LoRA support
-        for the connector module of the multi-modal model
-        Given the number of vision tokens, output the number of multi-modal connector tokens
-        """
-        raise NotImplementedError
-
     def _bind_and_group_updates(
         self,
         prompt_updates: Sequence[PromptUpdate],

@@ -593,9 +593,9 @@ class GPUModelRunner(
         # Multimodal LoRA support
         self.enable_tower_connector_lora = False
        if self.supports_mm_inputs and self.lora_config:
-            self.info = self.mm_registry.create_processor(self.model_config).info
+            self.mm_model_cls = self.mm_registry._get_model_cls(model_config)
             self.enable_tower_connector_lora = (
-                hasattr(self.info, "get_num_mm_encoder_tokens")
+                hasattr(self.mm_model_cls, "get_num_mm_encoder_tokens")
                 and self.lora_config.enable_tower_connector_lora
             )

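
Note: the runner now probes the model class returned by `_get_model_cls` rather than a processor instance; `hasattr` on the class itself is enough to detect the hook, as this hypothetical sketch shows:

class _ToyModelCls:
    def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
        return num_image_tokens * 4  # hypothetical factor

enable = hasattr(_ToyModelCls, "get_num_mm_encoder_tokens")  # no instance needed
assert enable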

@@ -2183,7 +2183,7 @@ class GPUModelRunner(
                 # Prefer pos_info.is_embed to count actual MM embedding tokens.
                 # pos_info.length may overcount (e.g., special tokens in Qwen-VL).
                 # Fall back to length if is_embed is None.
-                num_tokens = self.info.get_num_mm_encoder_tokens( # type: ignore[attr-defined]
+                num_tokens = model.get_num_mm_encoder_tokens( # type: ignore[attr-defined]
                     pos_info.get_num_embeds
                 )
                 prompt_lora_mapping.append(lora_id)
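
Note: a hypothetical illustration of the overcounting the comment describes — a placeholder span whose length includes special start/end tokens around the actual embeddings:

is_embed = [False] + [True] * 64 + [False]  # hypothetical per-position mask
num_embeds = sum(is_embed)  # 64: what the is_embed-based count reports
length = len(is_embed)      # 66: pos_info.length would overcount by 2
assert (num_embeds, length) == (64, 66)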

@@ -2202,13 +2202,13 @@ class GPUModelRunner(
             )
             self.lora_manager.set_active_adapters(lora_requests, lora_mapping)

-            if hasattr(self.info, "get_num_mm_connector_tokens"):
+            if hasattr(model, "get_num_mm_connector_tokens"):
                 num_post_op_tokens = []
                 for _, pos_info in mm_hashes_pos:
-                    mm_token_count = self.info.get_num_mm_encoder_tokens( # type: ignore[attr-defined]
+                    mm_token_count = model.get_num_mm_encoder_tokens( # type: ignore[attr-defined]
                         pos_info.length
                     )
-                    post_op_count = self.info.get_num_mm_connector_tokens( # type: ignore[attr-defined]
+                    post_op_count = model.get_num_mm_connector_tokens( # type: ignore[attr-defined]
                         mm_token_count
                     )
                     num_post_op_tokens.append(post_op_count)
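
Note: per multimodal item, the encoder count is derived from the placeholder length and then mapped through the connector. A sketch with hypothetical Qwen-style numbers (merge factor 4):

pos_info_lengths = [64, 144]  # hypothetical placeholder lengths per MM item
num_post_op_tokens = []
for length in pos_info_lengths:
    mm_token_count = length * 4          # stands in for get_num_mm_encoder_tokens
    post_op_count = mm_token_count // 4  # stands in for get_num_mm_connector_tokens
    num_post_op_tokens.append(post_op_count)
assert num_post_op_tokens == [64, 144]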