Signed-off-by: bk-201 <joy25810@foxmail.com>
bk-201 2025-12-16 16:30:20 +00:00
parent bdac2b5d17
commit 57917818a4


@@ -593,11 +593,13 @@ class GPUModelRunner(
         )
         # Multimodal LoRA support
-        if self.supports_mm_inputs:
+        self.enable_tower_connector_lora = False
+        if self.supports_mm_inputs and self.lora_config:
             self.info = self.mm_registry.create_processor(self.model_config).info
-            self.supports_mm_lora = hasattr(self.info, "get_num_mm_encoder_tokens")
-        else:
-            self.supports_mm_lora = False
+            self.enable_tower_connector_lora = (
+                hasattr(self.info, "get_num_mm_encoder_tokens")
+                and self.lora_config.enable_tower_connector_lora
+            )
         # Pre-allocated tensor for copying valid sampled token counts to CPU,
         # with dedicated stream for overlapping and event for coordination.
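A minimal runnable sketch of the new gating logic, using `types.SimpleNamespace` stand-ins; only the attribute names `supports_mm_inputs`, `get_num_mm_encoder_tokens`, and `enable_tower_connector_lora` are taken from the hunk above, the rest is illustrative:

```python
from types import SimpleNamespace

# Hypothetical stand-ins for the processor info and LoRA config; only the
# attribute names checked by the new code come from the diff above.
info = SimpleNamespace(get_num_mm_encoder_tokens=lambda: 0)
lora_config = SimpleNamespace(enable_tower_connector_lora=True)
supports_mm_inputs = True

# Mirrors the new initialization: default to False, then enable only when
# multimodal inputs, a LoRA config, the encoder-token hook, and the config
# opt-in are all present.
enable_tower_connector_lora = False
if supports_mm_inputs and lora_config:
    enable_tower_connector_lora = (
        hasattr(info, "get_num_mm_encoder_tokens")
        and lora_config.enable_tower_connector_lora
    )

print(enable_tower_connector_lora)  # True only if every condition holds
```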
@@ -2148,7 +2150,7 @@ class GPUModelRunner(
         # encoder outputs.
         model = cast(SupportsMultiModal, self.model)
-        if self.lora_config and self.supports_mm_lora:
+        if self.enable_tower_connector_lora:
             # Build LoRA mappings independently for encoder inputs
             # (encoder batch structure is different from main batch)
             prompt_lora_mapping = []
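The second hunk swaps the old two-attribute check for the precomputed flag. A small sketch of the equivalence, again with `SimpleNamespace` stand-ins (names other than those in the hunks above are illustrative):

```python
from types import SimpleNamespace

# Illustrative runner state; only the attribute names come from the hunks.
runner = SimpleNamespace(
    lora_config=SimpleNamespace(enable_tower_connector_lora=True),
    supports_mm_lora=True,             # old attribute, checked per call
    enable_tower_connector_lora=True,  # new flag, computed once at init
)

# Before: two attributes re-checked at the encoder call site.
old_gate = bool(runner.lora_config) and runner.supports_mm_lora
# After: one flag that also folds in the lora_config opt-in (first hunk).
new_gate = runner.enable_tower_connector_lora

print(old_gate, new_gate)
```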
@@ -2371,37 +2373,6 @@ class GPUModelRunner(
         return mm_embeds, is_mm_embed
-
-    def _extract_encoder_inputs(
-        self,
-        scheduler_output: "SchedulerOutput",
-    ) -> dict[str, torch.Tensor]:
-        """Extract encoder inputs for encoder-decoder models.
-
-        This method extracts multimodal input features from scheduled encoder
-        inputs and formats them for the encoder-decoder model forward pass.
-        """
-        # Batch the multi-modal inputs using the helper method.
-        mm_kwargs, _, _ = self._batch_mm_kwargs_from_scheduler(scheduler_output)
-        if not mm_kwargs:
-            return {}
-
-        # Group MM kwargs by modality and extract features
-        model = cast(SupportsMultiModal, self.model)
-        encoder_features = {}
-        for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
-            mm_kwargs,
-            device=self.device,
-            pin_memory=self.pin_memory,
-            merge_by_field_config=model.merge_by_field_config,
-        ):
-            # Add the grouped features to encoder_features dict
-            # This allows the model to receive them as kwargs (e.g.,
-            # input_features=...)
-            encoder_features.update(mm_kwargs_group)
-
-        return encoder_features
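The removed helper's docstring and comments describe a dict of encoder features that is splatted into the model call as keyword arguments (e.g. `input_features=...`). A hedged sketch of that pass-through pattern; the fake model below is illustrative, and the actual call site is not part of this diff:

```python
# Fake encoder-decoder forward, for illustration only; the real call site
# and model signature are not shown in this diff.
def fake_encoder_decoder_model(input_ids, *, input_features=None):
    return {"input_ids": input_ids, "input_features": input_features}

# A dict shaped like the one the removed _extract_encoder_inputs returned,
# with fake feature values.
encoder_features = {"input_features": [[0.1, 0.2, 0.3]]}

# Features are forwarded as kwargs, matching the removed comment
# "This allows the model to receive them as kwargs (e.g., input_features=...)".
out = fake_encoder_decoder_model([1, 2, 3], **encoder_features)
print(out)
```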
     def get_model(self) -> nn.Module:
         # get raw model out of the cudagraph wrapper.
         if isinstance(self.model, (CUDAGraphWrapper, UBatchWrapper)):