commit 57d7267fee
parent ac8afb6ae8
Author: bk-201 <joy25810@foxmail.com>
Date:   2025-12-24 04:43:05 +00:00
Signed-off-by: bk-201 <joy25810@foxmail.com>

@@ -2167,57 +2167,48 @@ class GPUModelRunner(
         prompt_lora_mapping = []
         token_lora_mapping = []
         lora_requests = set()
+        encoder_token_counts = []
         for req_id, pos_info in mm_lora_refs:
             req_idx = self.input_batch.req_id_to_index[req_id]
             lora_id = int(self.input_batch.request_lora_mapping[req_idx])
-            # Prefer pos_info.is_embed to count actual MM embedding tokens.
-            # pos_info.length may overcount (e.g., special tokens in Qwen-VL).
-            # Fall back to length if is_embed is None.
+            # Prefer pos_info.get_num_embeds to count precise MM embedding tokens.
             num_tokens = self.model.get_num_mm_encoder_tokens(  # type: ignore[attr-defined]
                 pos_info.get_num_embeds
             )
             prompt_lora_mapping.append(lora_id)
             token_lora_mapping.extend([lora_id] * num_tokens)
+            encoder_token_counts.append(num_tokens)
             if lora_id > 0:
                 lora_request = self.input_batch.lora_id_to_lora_request.get(lora_id)
                 if lora_request is not None:
                     lora_requests.add(lora_request)
-        lora_mapping = LoRAMapping(
+        # Set tower adapter mapping
+        tower_mapping = LoRAMapping(
             tuple(token_lora_mapping),
             tuple(prompt_lora_mapping),
             is_prefill=True,
+            type=LoRAMappingType.TOWER,
         )
-        self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
+        self.lora_manager.set_active_adapters(lora_requests, tower_mapping)
         if hasattr(self.model, "get_num_mm_connector_tokens"):
-            num_post_op_tokens = []
-            for _, pos_info in mm_lora_refs:
-                mm_token_count = self.model.get_num_mm_encoder_tokens(  # type: ignore[attr-defined]
-                    pos_info.length
-                )
-                post_op_count = self.model.get_num_mm_connector_tokens(  # type: ignore[attr-defined]
-                    mm_token_count
-                )
-                num_post_op_tokens.append(post_op_count)
+            post_op_counts = [
+                self.model.get_num_mm_connector_tokens(num_tokens)  # type: ignore[attr-defined]
+                for num_tokens in encoder_token_counts
+            ]
-            last_mapping = self.lora_manager._adapter_manager._last_mapping
-            assert last_mapping is not None
-            lora_ids = np.array(
-                last_mapping.prompt_mapping,
-                dtype=np.int32,
+            connector_token_mapping = np.repeat(
+                np.array(prompt_lora_mapping, dtype=np.int32),
+                np.array(post_op_counts, dtype=np.int32),
             )
-            post_op_counts_np = np.array(num_post_op_tokens, dtype=np.int32)
-            new_token_indices = lora_ids.repeat(post_op_counts_np)
             connector_mapping = LoRAMapping(
-                index_mapping=tuple(new_token_indices.tolist()),
-                prompt_mapping=last_mapping.prompt_mapping,
-                is_prefill=last_mapping.is_prefill,
+                index_mapping=tuple(connector_token_mapping.tolist()),
+                prompt_mapping=tuple(prompt_lora_mapping),
+                is_prefill=True,
                 type=LoRAMappingType.CONNECTOR,
             )
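
Note on the new connector mapping: np.repeat with a per-element count array expands the per-request adapter IDs into one ID per connector token, which is all the replaced _last_mapping readback was ultimately computing. A minimal standalone sketch of that expansion, with made-up request counts and the LoRAMapping construction and model hooks omitted:

import numpy as np

# Hypothetical per-request values, for illustration only: three
# multimodal requests using LoRA ids 1, 0 (base model), and 2,
# producing 4, 2, and 3 connector tokens respectively.
prompt_lora_mapping = [1, 0, 2]
post_op_counts = [4, 2, 3]

# np.repeat with per-element counts expands the per-request ids into
# one adapter id per connector token, as in the commit above.
connector_token_mapping = np.repeat(
    np.array(prompt_lora_mapping, dtype=np.int32),
    np.array(post_op_counts, dtype=np.int32),
)
assert connector_token_mapping.tolist() == [1, 1, 1, 1, 0, 0, 2, 2, 2]

Deriving the mapping from prompt_lora_mapping directly, rather than reading back self.lora_manager._adapter_manager._last_mapping, also removes the dependency on the adapter manager's private state.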