add mm_punica_wrapper

Signed-off-by: bk-201 <joy25810@foxmail.com>
This commit is contained in:
bk-201 2025-05-22 00:31:56 +08:00
parent 23baa2180b
commit e7026a7c50
9 changed files with 200 additions and 36 deletions

View File

@ -27,6 +27,10 @@ argcomplete==3.5.1
# via datamodel-code-generator
arrow==1.3.0
# via isoduration
async-timeout==5.0.1
# via
# aiohttp
# redis
attrs==24.2.0
# via
# aiohttp
@ -129,6 +133,11 @@ eval-type-backport==0.2.2
# via mteb
evaluate==0.4.3
# via lm-eval
exceptiongroup==1.3.0
# via
# anyio
# hypothesis
# pytest
fastparquet==2024.11.0
# via genai-perf
fastrlock==0.8.2
@ -640,7 +649,6 @@ setuptools==77.0.3
# via
# mamba-ssm
# pytablewriter
# torch
# triton
shellingham==1.5.4
# via typer
@ -700,8 +708,13 @@ tokenizers==0.21.1
# via
# -r requirements/test.in
# transformers
toml==0.10.2
# via datamodel-code-generator
tomli==2.2.1
# via schemathesis
# via
# black
# pytest
# schemathesis
tomli-w==1.2.0
# via schemathesis
torch==2.7.0+cu128
@ -775,13 +788,18 @@ types-python-dateutil==2.9.0.20241206
# via arrow
typing-extensions==4.12.2
# via
# anyio
# black
# exceptiongroup
# huggingface-hub
# librosa
# mistral-common
# mteb
# multidict
# pqdm
# pydantic
# pydantic-core
# rich
# torch
# typer
tzdata==2024.2

View File

@ -77,6 +77,7 @@ def _not_fully_sharded_can_replace(can_replace):
@dataclass
class LoRAMapping(AdapterMapping):
is_prefill: bool = False
is_mm_input: bool = False
class BaseLayerWithLoRA(nn.Module):
@ -410,6 +411,9 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
# Store original shape for later reshaping
original_shape = output.shape if output.ndim == 3 else None
# In transformers backend, x and output have extra batch dimension like
# (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
# therefore we need to flatten the batch dimensions.
@ -424,6 +428,10 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
if not current_platform.can_update_inplace():
output = lora_output
# Restore original shape if it was flattened
if original_shape is not None:
output = output.reshape(original_shape)
return output
@property

View File

@ -17,14 +17,14 @@ from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel,
from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter,
get_adapter, list_adapters,
remove_adapter, set_adapter_mapping)
from vllm.config import LoRAConfig
from vllm.config import LoRAConfig, ModelConfig
from vllm.logger import init_logger
from vllm.lora.layers import (BaseLayerWithLoRA,
LinearScalingRotaryEmbeddingWithLoRA,
LoRAMapping)
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.punica_wrapper import get_punica_wrapper
from vllm.lora.punica_wrapper import PunicaWrapperBase, get_punica_wrapper
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
get_supported_lora_modules,
is_regex_target_modules,
@ -33,6 +33,7 @@ from vllm.model_executor.models import SupportsLoRA, supports_multimodal
from vllm.model_executor.models.interfaces import is_pooling_model
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.utils import is_pin_memory_available
logger = init_logger(__name__)
@ -311,6 +312,7 @@ class LoRAModelManager(AdapterModelManager):
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
model_config: Optional[ModelConfig],
device: torch.device,
):
"""Create a LoRAModelManager and adapter for a given model.
@ -357,6 +359,30 @@ class LoRAModelManager(AdapterModelManager):
# In case the model only supports LoRA for
# text modules (e.g. ChatGLM)
and hasattr(self.model, "get_mm_mapping"))
# For v0 compatibility
if model_config is not None:
self.mm_registry = MULTIMODAL_REGISTRY
self.info = self.mm_registry.create_processor(
model_config, disable_cache=True).info
self.supports_mm_lora = self.supports_mm and hasattr(
self.info, "get_num_mm_encoder_tokens")
else:
self.supports_mm_lora = False
if self.supports_mm_lora:
self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
self.mm_punica_wrapper_mapping = {
name:
get_punica_wrapper(
self.info.get_num_mm_encoder_tokens(
max_num_batched_tokens),
max_batches=self.max_num_seqs, # TODO
device=self.device,
max_loras=self.lora_config.max_loras,
)
for name in self.mm_mapping.tower_model
}
self.mm_punica_wrapper_mapping[
self.mm_mapping.language_model[0]] = self.punica_wrapper
self.is_pooling_model = is_pooling_model(self.model)
self.packed_modules: dict[str, list[str]] = {}
self.modules: dict[str, BaseLayerWithLoRA] = {}
@ -452,14 +478,35 @@ class LoRAModelManager(AdapterModelManager):
def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
# update lora states
self.punica_wrapper.update_metadata(
mapping,
self.lora_index_to_id,
self.lora_slots + 1,
self.vocab_size,
self.lora_config.lora_extra_vocab_size,
self.long_lora_context,
)
if not self.supports_mm_lora:
self.punica_wrapper.update_metadata(
mapping,
self.lora_index_to_id,
self.lora_slots + 1,
self.vocab_size,
self.lora_config.lora_extra_vocab_size,
self.long_lora_context,
)
elif mapping.is_mm_input:
self.mm_punica_wrapper_mapping[
self.mm_mapping.tower_model[0]].update_metadata(
mapping,
self.lora_index_to_id,
self.lora_slots + 1,
self.vocab_size,
self.lora_config.lora_extra_vocab_size,
self.long_lora_context,
)
else:
self.mm_punica_wrapper_mapping[
self.mm_mapping.language_model[0]].update_metadata(
mapping,
self.lora_index_to_id,
self.lora_slots + 1,
self.vocab_size,
self.lora_config.lora_extra_vocab_size,
self.long_lora_context,
)
def remove_all_adapters(self):
"""Remove all LoRAModels from the manager."""
@ -476,7 +523,9 @@ class LoRAModelManager(AdapterModelManager):
continue
# A temporary approach for multimodal models to support LoRA
# TODO: Remove this restriction
if self._filter_unsupported_mm_module(module_name):
if (self._filter_unsupported_mm_module(module_name)
and not self.supports_mm_lora
or self._get_mm_punica_wrapper(module_name) is None):
logger.warning(
"Regarding multimodal models, vLLM currently only supports "
"adding LoRA to language model, %s will be ignored.",
@ -519,7 +568,11 @@ class LoRAModelManager(AdapterModelManager):
self.register_module(module_name, new_module)
self._register_packed_modules(module_name)
# All lora layers share the same punica_wrapper based on reference.
new_module.set_mapping(self.punica_wrapper)
if self.supports_mm_lora:
new_module.set_mapping(
self._get_mm_punica_wrapper(module_name))
else:
new_module.set_mapping(self.punica_wrapper)
def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
assert isinstance(module, BaseLayerWithLoRA)
@ -615,6 +668,19 @@ class LoRAModelManager(AdapterModelManager):
[module_name.startswith(prefix) for prefix in prefix_lst])
return False
def _get_mm_punica_wrapper(
        self, module_name: str) -> Optional["PunicaWrapperBase"]:
    """Resolve the punica wrapper that serves ``module_name``.

    ``mm_punica_wrapper_mapping`` is keyed by module-name prefix (one entry
    per tower model plus the language model). The first prefix that matches
    ``module_name`` wins.

    Returns:
        The matching ``PunicaWrapperBase``, or ``None`` when multimodal
        LoRA is not supported or no prefix matches — callers treat ``None``
        as "skip LoRA for this module".
    """
    if self.supports_mm_lora:
        for prefix, punica_wrapper in self.mm_punica_wrapper_mapping.items():
            if module_name.startswith(prefix):
                return punica_wrapper
    return None
def _register_packed_modules(self, module_full_name: str) -> None:
parts = module_full_name.split(".")
module_name = parts[-1]
@ -713,9 +779,10 @@ class LRUCacheLoRAModelManager(LoRAModelManager):
def __init__(self, model: nn.Module, max_num_seqs: int,
max_num_batched_tokens: int, vocab_size: int,
lora_config: LoRAConfig, device: torch.device):
lora_config: LoRAConfig, model_config: ModelConfig,
device: torch.device):
super().__init__(model, max_num_seqs, max_num_batched_tokens,
vocab_size, lora_config, device)
vocab_size, lora_config, model_config, device)
self._registered_adapters: LoRALRUCache = LoRALRUCache(
self.capacity, self.deactivate_adapter)
self._active_adapters: LoRALRUCache = LoRALRUCache(
@ -785,6 +852,7 @@ def create_lora_manager(
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
model_config: ModelConfig,
device: torch.device,
lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
**kwargs) -> LoRAModelManager:
@ -797,6 +865,7 @@ def create_lora_manager(
max_num_batched_tokens=max_num_batched_tokens,
vocab_size=vocab_size,
lora_config=lora_config,
model_config=model_config,
device=device,
**kwargs)
return lora_manager

View File

@ -10,7 +10,7 @@ from vllm.adapter_commons.utils import (add_adapter_worker,
list_adapters_worker,
set_active_adapters_worker)
from vllm.adapter_commons.worker_manager import AbstractWorkerManager
from vllm.config import LoRAConfig
from vllm.config import LoRAConfig, ModelConfig
from vllm.logger import init_logger
from vllm.lora.models import (LoRAModel, LoRAModelManager,
LRUCacheLoRAModelManager, create_lora_manager)
@ -200,6 +200,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
def create_lora_manager(
self,
model: torch.nn.Module,
model_config: Optional[ModelConfig] = None,
) -> Any:
lora_manager = create_lora_manager(
model,
@ -209,6 +210,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
lora_config=self.lora_config,
device=self.device,
max_num_batched_tokens=self.max_num_batched_tokens,
model_config=model_config,
)
self._adapter_manager = lora_manager
return lora_manager.model

View File

@ -279,6 +279,15 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
height=image_processor.size["longest_edge"],
)
def get_num_mm_encoder_tokens(
    self,
    num_image_tokens: int,
) -> int:
    """Return the number of encoder-side tokens for ``num_image_tokens``.

    The HF config's ``scale_factor`` relates encoder tokens to final image
    tokens: each image token corresponds to ``scale_factor ** 2`` encoder
    tokens.
    """
    scale = self.get_hf_config().scale_factor
    return num_image_tokens * scale * scale
class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]
):

View File

@ -962,6 +962,16 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
image_processor=None,
)
def get_num_mm_encoder_tokens(
    self,
    num_image_tokens: int,
) -> int:
    """Return the number of encoder-side tokens for ``num_image_tokens``.

    The vision config's ``spatial_merge_size`` relates encoder tokens to
    final image tokens: each image token corresponds to
    ``spatial_merge_size ** 2`` encoder tokens.
    """
    merge = self.get_hf_config().vision_config.spatial_merge_size
    return num_image_tokens * merge * merge
class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):

View File

@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
import copy
from abc import ABC
from collections.abc import Mapping
from dataclasses import dataclass, field
@ -44,6 +45,7 @@ class DummyDecoderData(NamedTuple):
prompt_token_ids: list[int]
multi_modal_data: MultiModalKwargs
multi_modal_placeholders: MultiModalPlaceholderDict
multi_modal_token_ids: list[int]
_I = TypeVar("_I", bound=BaseProcessingInfo)
@ -249,6 +251,7 @@ class MultiModalProfiler(Generic[_I]):
str(self._get_mm_num_tokens(mm_inputs)),
)
multi_modal_token_ids = copy.deepcopy(prompt_token_ids)
if total_len < seq_len:
prompt_token_ids.extend([0] * (seq_len - total_len))
@ -256,6 +259,7 @@ class MultiModalProfiler(Generic[_I]):
prompt_token_ids=prompt_token_ids,
multi_modal_data=mm_inputs["mm_kwargs"],
multi_modal_placeholders=mm_inputs["mm_placeholders"],
multi_modal_token_ids=multi_modal_token_ids,
)
def get_mm_max_tokens(

View File

@ -263,6 +263,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
pin_memory=self.pin_memory)
self.seq_lens_np = self.seq_lens_cpu.numpy()
# Multimodal LoRA support
if self.is_multimodal_model:
self.info = self.mm_registry.create_processor(
self.model_config, disable_cache=True).info
self.supports_mm_lora = hasattr(self.info,
"get_num_mm_encoder_tokens")
else:
self.supports_mm_lora = False
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool:
"""
Update the order of requests in the batch based on the attention
@ -892,12 +901,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return
# Batch the multi-modal inputs.
mm_tokens = list[int]()
mm_inputs = list[MultiModalKwargs]()
req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id]
for mm_input_id in encoder_input_ids:
mm_tokens.append(req_state.mm_positions[mm_input_id].length)
mm_inputs.append(req_state.mm_inputs[mm_input_id])
req_ids_pos.append(
(req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
@ -911,6 +922,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# encoder outputs.
grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)
if self.lora_config and self.supports_mm_lora:
mm_tokens = [
self.info.get_num_mm_encoder_tokens(num_token)
for num_token in mm_tokens
]
num_scheduled_tokens = np.array(mm_tokens, dtype=np.int32)
self.set_active_loras(self.input_batch,
num_scheduled_tokens,
is_mm_input=True)
encoder_outputs = []
for grouped_mm_inputs in grouped_mm_inputs_list:
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
@ -1826,22 +1847,38 @@ class GPUModelRunner(LoRAModelRunnerMixin):
encoder_budget, max_num_mm_items, dummy_data_modality)
# Create dummy batch of multimodal inputs.
dummy_mm_kwargs = self.mm_registry.get_decoder_dummy_data(
dummy_mm_data = self.mm_registry.get_decoder_dummy_data(
model_config=self.model_config,
seq_len=self.max_num_tokens,
mm_counts={
dummy_data_modality: 1
},
).multi_modal_data
mm_counts={dummy_data_modality: 1},
)
dummy_mm_kwargs = dummy_mm_data.multi_modal_data
dummy_mm_token_ids = dummy_mm_data.multi_modal_token_ids
max_num_mm_items = 1 # temporary
batched_dummy_mm_inputs = MultiModalKwargs.batch(
[dummy_mm_kwargs] * max_num_mm_items)
[dummy_mm_kwargs] * max_num_mm_items) # ???
batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
batched_dummy_mm_inputs, device=self.device)
# Run multimodal encoder.
dummy_encoder_outputs = self.model.get_multimodal_embeddings(
**batched_dummy_mm_inputs)
if self.supports_mm_lora:
num_scheduled_tokens_list = [
self.info.get_num_mm_encoder_tokens(
len(dummy_mm_token_ids))
] * max_num_mm_items
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
dtype=np.int32)
lora_config = self.lora_config
else:
num_scheduled_tokens = None
lora_config = None
with self.maybe_dummy_run_with_lora(lora_config,
num_scheduled_tokens,
is_mm_input=True):
# Run multimodal encoder.
dummy_encoder_outputs = self.model.get_multimodal_embeddings(
**batched_dummy_mm_inputs)
sanity_check_mm_encoder_outputs(
dummy_encoder_outputs,

View File

@ -50,11 +50,13 @@ class LoRAModelRunnerMixin:
model.embedding_padding_modules,
max_position_embeddings=text_config.max_position_embeddings,
)
return self.lora_manager.create_lora_manager(model)
return self.lora_manager.create_lora_manager(model, model_config)
def _set_active_loras(self, prompt_lora_mapping: tuple[int, ...],
def _set_active_loras(self,
prompt_lora_mapping: tuple[int, ...],
token_lora_mapping: tuple[int, ...],
lora_requests: set[LoRARequest]) -> None:
lora_requests: set[LoRARequest],
is_mm_input: bool = False) -> None:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
@ -64,11 +66,14 @@ class LoRAModelRunnerMixin:
# decode and this flag is generally ignored.
lora_mapping = LoRAMapping(token_lora_mapping,
prompt_lora_mapping,
is_prefill=True)
is_prefill=True,
is_mm_input=is_mm_input)
self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
def set_active_loras(self, input_batch: InputBatch,
num_scheduled_tokens: np.ndarray) -> None:
def set_active_loras(self,
input_batch: InputBatch,
num_scheduled_tokens: np.ndarray,
is_mm_input: bool = False) -> None:
prompt_lora_mapping: tuple[int, ...] # of size input_batch.num_reqs
token_lora_mapping: tuple[int,
@ -77,11 +82,13 @@ class LoRAModelRunnerMixin:
prompt_lora_mapping, token_lora_mapping, lora_requests = \
input_batch.make_lora_inputs(num_scheduled_tokens)
return self._set_active_loras(prompt_lora_mapping, token_lora_mapping,
lora_requests)
lora_requests, is_mm_input)
@contextmanager
def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig,
num_scheduled_tokens: np.ndarray):
def maybe_dummy_run_with_lora(self,
lora_config: LoRAConfig,
num_scheduled_tokens: np.ndarray,
is_mm_input: bool = False):
if lora_config is None:
yield
else:
@ -117,7 +124,7 @@ class LoRAModelRunnerMixin:
self._set_active_loras(tuple(prompt_lora_mapping),
tuple(token_lora_mapping),
lora_requests)
lora_requests, is_mm_input)
yield