mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-03 20:49:08 +08:00
add mm_punica_warpper
Signed-off-by: bk-201 <joy25810@foxmail.com>
This commit is contained in:
parent
23baa2180b
commit
e7026a7c50
@ -27,6 +27,10 @@ argcomplete==3.5.1
|
|||||||
# via datamodel-code-generator
|
# via datamodel-code-generator
|
||||||
arrow==1.3.0
|
arrow==1.3.0
|
||||||
# via isoduration
|
# via isoduration
|
||||||
|
async-timeout==5.0.1
|
||||||
|
# via
|
||||||
|
# aiohttp
|
||||||
|
# redis
|
||||||
attrs==24.2.0
|
attrs==24.2.0
|
||||||
# via
|
# via
|
||||||
# aiohttp
|
# aiohttp
|
||||||
@ -129,6 +133,11 @@ eval-type-backport==0.2.2
|
|||||||
# via mteb
|
# via mteb
|
||||||
evaluate==0.4.3
|
evaluate==0.4.3
|
||||||
# via lm-eval
|
# via lm-eval
|
||||||
|
exceptiongroup==1.3.0
|
||||||
|
# via
|
||||||
|
# anyio
|
||||||
|
# hypothesis
|
||||||
|
# pytest
|
||||||
fastparquet==2024.11.0
|
fastparquet==2024.11.0
|
||||||
# via genai-perf
|
# via genai-perf
|
||||||
fastrlock==0.8.2
|
fastrlock==0.8.2
|
||||||
@ -640,7 +649,6 @@ setuptools==77.0.3
|
|||||||
# via
|
# via
|
||||||
# mamba-ssm
|
# mamba-ssm
|
||||||
# pytablewriter
|
# pytablewriter
|
||||||
# torch
|
|
||||||
# triton
|
# triton
|
||||||
shellingham==1.5.4
|
shellingham==1.5.4
|
||||||
# via typer
|
# via typer
|
||||||
@ -700,8 +708,13 @@ tokenizers==0.21.1
|
|||||||
# via
|
# via
|
||||||
# -r requirements/test.in
|
# -r requirements/test.in
|
||||||
# transformers
|
# transformers
|
||||||
|
toml==0.10.2
|
||||||
|
# via datamodel-code-generator
|
||||||
tomli==2.2.1
|
tomli==2.2.1
|
||||||
# via schemathesis
|
# via
|
||||||
|
# black
|
||||||
|
# pytest
|
||||||
|
# schemathesis
|
||||||
tomli-w==1.2.0
|
tomli-w==1.2.0
|
||||||
# via schemathesis
|
# via schemathesis
|
||||||
torch==2.7.0+cu128
|
torch==2.7.0+cu128
|
||||||
@ -775,13 +788,18 @@ types-python-dateutil==2.9.0.20241206
|
|||||||
# via arrow
|
# via arrow
|
||||||
typing-extensions==4.12.2
|
typing-extensions==4.12.2
|
||||||
# via
|
# via
|
||||||
|
# anyio
|
||||||
|
# black
|
||||||
|
# exceptiongroup
|
||||||
# huggingface-hub
|
# huggingface-hub
|
||||||
# librosa
|
# librosa
|
||||||
# mistral-common
|
# mistral-common
|
||||||
# mteb
|
# mteb
|
||||||
|
# multidict
|
||||||
# pqdm
|
# pqdm
|
||||||
# pydantic
|
# pydantic
|
||||||
# pydantic-core
|
# pydantic-core
|
||||||
|
# rich
|
||||||
# torch
|
# torch
|
||||||
# typer
|
# typer
|
||||||
tzdata==2024.2
|
tzdata==2024.2
|
||||||
|
|||||||
@ -77,6 +77,7 @@ def _not_fully_sharded_can_replace(can_replace):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class LoRAMapping(AdapterMapping):
|
class LoRAMapping(AdapterMapping):
|
||||||
is_prefill: bool = False
|
is_prefill: bool = False
|
||||||
|
is_mm_input: bool = False
|
||||||
|
|
||||||
|
|
||||||
class BaseLayerWithLoRA(nn.Module):
|
class BaseLayerWithLoRA(nn.Module):
|
||||||
@ -410,6 +411,9 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
|
|||||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||||
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
|
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
|
||||||
|
|
||||||
|
# Store original shape for later reshaping
|
||||||
|
original_shape = output.shape if output.ndim == 3 else None
|
||||||
|
|
||||||
# In transformers backend, x and output have extra batch dimension like
|
# In transformers backend, x and output have extra batch dimension like
|
||||||
# (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
|
# (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
|
||||||
# therefore we need to flatten the batch dimensions.
|
# therefore we need to flatten the batch dimensions.
|
||||||
@ -424,6 +428,10 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
|
|||||||
if not current_platform.can_update_inplace():
|
if not current_platform.can_update_inplace():
|
||||||
output = lora_output
|
output = lora_output
|
||||||
|
|
||||||
|
# Restore original shape if it was flattened
|
||||||
|
if original_shape is not None:
|
||||||
|
output = output.reshape(original_shape)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@ -17,14 +17,14 @@ from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel,
|
|||||||
from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter,
|
from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter,
|
||||||
get_adapter, list_adapters,
|
get_adapter, list_adapters,
|
||||||
remove_adapter, set_adapter_mapping)
|
remove_adapter, set_adapter_mapping)
|
||||||
from vllm.config import LoRAConfig
|
from vllm.config import LoRAConfig, ModelConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.layers import (BaseLayerWithLoRA,
|
from vllm.lora.layers import (BaseLayerWithLoRA,
|
||||||
LinearScalingRotaryEmbeddingWithLoRA,
|
LinearScalingRotaryEmbeddingWithLoRA,
|
||||||
LoRAMapping)
|
LoRAMapping)
|
||||||
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
|
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
|
||||||
from vllm.lora.peft_helper import PEFTHelper
|
from vllm.lora.peft_helper import PEFTHelper
|
||||||
from vllm.lora.punica_wrapper import get_punica_wrapper
|
from vllm.lora.punica_wrapper import PunicaWrapperBase, get_punica_wrapper
|
||||||
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
|
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
|
||||||
get_supported_lora_modules,
|
get_supported_lora_modules,
|
||||||
is_regex_target_modules,
|
is_regex_target_modules,
|
||||||
@ -33,6 +33,7 @@ from vllm.model_executor.models import SupportsLoRA, supports_multimodal
|
|||||||
from vllm.model_executor.models.interfaces import is_pooling_model
|
from vllm.model_executor.models.interfaces import is_pooling_model
|
||||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||||
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
|
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
|
||||||
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
from vllm.utils import is_pin_memory_available
|
from vllm.utils import is_pin_memory_available
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -311,6 +312,7 @@ class LoRAModelManager(AdapterModelManager):
|
|||||||
max_num_batched_tokens: int,
|
max_num_batched_tokens: int,
|
||||||
vocab_size: int,
|
vocab_size: int,
|
||||||
lora_config: LoRAConfig,
|
lora_config: LoRAConfig,
|
||||||
|
model_config: Optional[ModelConfig],
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
):
|
):
|
||||||
"""Create a LoRAModelManager and adapter for a given model.
|
"""Create a LoRAModelManager and adapter for a given model.
|
||||||
@ -357,6 +359,30 @@ class LoRAModelManager(AdapterModelManager):
|
|||||||
# In case the model only supports LoRA for
|
# In case the model only supports LoRA for
|
||||||
# text modules (e.g. ChatGLM)
|
# text modules (e.g. ChatGLM)
|
||||||
and hasattr(self.model, "get_mm_mapping"))
|
and hasattr(self.model, "get_mm_mapping"))
|
||||||
|
# For v0 compatibility
|
||||||
|
if model_config is not None:
|
||||||
|
self.mm_registry = MULTIMODAL_REGISTRY
|
||||||
|
self.info = self.mm_registry.create_processor(
|
||||||
|
model_config, disable_cache=True).info
|
||||||
|
self.supports_mm_lora = self.supports_mm and hasattr(
|
||||||
|
self.info, "get_num_mm_encoder_tokens")
|
||||||
|
else:
|
||||||
|
self.supports_mm_lora = False
|
||||||
|
if self.supports_mm_lora:
|
||||||
|
self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
|
||||||
|
self.mm_punica_wrapper_mapping = {
|
||||||
|
name:
|
||||||
|
get_punica_wrapper(
|
||||||
|
self.info.get_num_mm_encoder_tokens(
|
||||||
|
max_num_batched_tokens),
|
||||||
|
max_batches=self.max_num_seqs, # TODO
|
||||||
|
device=self.device,
|
||||||
|
max_loras=self.lora_config.max_loras,
|
||||||
|
)
|
||||||
|
for name in self.mm_mapping.tower_model
|
||||||
|
}
|
||||||
|
self.mm_punica_wrapper_mapping[
|
||||||
|
self.mm_mapping.language_model[0]] = self.punica_wrapper
|
||||||
self.is_pooling_model = is_pooling_model(self.model)
|
self.is_pooling_model = is_pooling_model(self.model)
|
||||||
self.packed_modules: dict[str, list[str]] = {}
|
self.packed_modules: dict[str, list[str]] = {}
|
||||||
self.modules: dict[str, BaseLayerWithLoRA] = {}
|
self.modules: dict[str, BaseLayerWithLoRA] = {}
|
||||||
@ -452,14 +478,35 @@ class LoRAModelManager(AdapterModelManager):
|
|||||||
|
|
||||||
def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
|
def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
|
||||||
# update lora states
|
# update lora states
|
||||||
self.punica_wrapper.update_metadata(
|
if not self.supports_mm_lora:
|
||||||
mapping,
|
self.punica_wrapper.update_metadata(
|
||||||
self.lora_index_to_id,
|
mapping,
|
||||||
self.lora_slots + 1,
|
self.lora_index_to_id,
|
||||||
self.vocab_size,
|
self.lora_slots + 1,
|
||||||
self.lora_config.lora_extra_vocab_size,
|
self.vocab_size,
|
||||||
self.long_lora_context,
|
self.lora_config.lora_extra_vocab_size,
|
||||||
)
|
self.long_lora_context,
|
||||||
|
)
|
||||||
|
elif mapping.is_mm_input:
|
||||||
|
self.mm_punica_wrapper_mapping[
|
||||||
|
self.mm_mapping.tower_model[0]].update_metadata(
|
||||||
|
mapping,
|
||||||
|
self.lora_index_to_id,
|
||||||
|
self.lora_slots + 1,
|
||||||
|
self.vocab_size,
|
||||||
|
self.lora_config.lora_extra_vocab_size,
|
||||||
|
self.long_lora_context,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.mm_punica_wrapper_mapping[
|
||||||
|
self.mm_mapping.language_model[0]].update_metadata(
|
||||||
|
mapping,
|
||||||
|
self.lora_index_to_id,
|
||||||
|
self.lora_slots + 1,
|
||||||
|
self.vocab_size,
|
||||||
|
self.lora_config.lora_extra_vocab_size,
|
||||||
|
self.long_lora_context,
|
||||||
|
)
|
||||||
|
|
||||||
def remove_all_adapters(self):
|
def remove_all_adapters(self):
|
||||||
"""Remove all LoRAModels from the manager."""
|
"""Remove all LoRAModels from the manager."""
|
||||||
@ -476,7 +523,9 @@ class LoRAModelManager(AdapterModelManager):
|
|||||||
continue
|
continue
|
||||||
# A temporary approach for multimodal models to support LoRA
|
# A temporary approach for multimodal models to support LoRA
|
||||||
# TODO: Remove this restriction
|
# TODO: Remove this restriction
|
||||||
if self._filter_unsupported_mm_module(module_name):
|
if (self._filter_unsupported_mm_module(module_name)
|
||||||
|
and not self.supports_mm_lora
|
||||||
|
or self._get_mm_punica_wrapper(module_name) is None):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Regarding multimodal models, vLLM currently only supports "
|
"Regarding multimodal models, vLLM currently only supports "
|
||||||
"adding LoRA to language model, %s will be ignored.",
|
"adding LoRA to language model, %s will be ignored.",
|
||||||
@ -519,7 +568,11 @@ class LoRAModelManager(AdapterModelManager):
|
|||||||
self.register_module(module_name, new_module)
|
self.register_module(module_name, new_module)
|
||||||
self._register_packed_modules(module_name)
|
self._register_packed_modules(module_name)
|
||||||
# All lora layers share the same punica_wrapper based on reference.
|
# All lora layers share the same punica_wrapper based on reference.
|
||||||
new_module.set_mapping(self.punica_wrapper)
|
if self.supports_mm_lora:
|
||||||
|
new_module.set_mapping(
|
||||||
|
self._get_mm_punica_wrapper(module_name))
|
||||||
|
else:
|
||||||
|
new_module.set_mapping(self.punica_wrapper)
|
||||||
|
|
||||||
def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
|
def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
|
||||||
assert isinstance(module, BaseLayerWithLoRA)
|
assert isinstance(module, BaseLayerWithLoRA)
|
||||||
@ -615,6 +668,19 @@ class LoRAModelManager(AdapterModelManager):
|
|||||||
[module_name.startswith(prefix) for prefix in prefix_lst])
|
[module_name.startswith(prefix) for prefix in prefix_lst])
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def _get_mm_punica_wrapper(self, module_name: str) -> PunicaWrapperBase:
|
||||||
|
"""
|
||||||
|
TODO
|
||||||
|
"""
|
||||||
|
if self.supports_mm_lora:
|
||||||
|
for (
|
||||||
|
prefix,
|
||||||
|
punica_wrapper,
|
||||||
|
) in self.mm_punica_wrapper_mapping.items():
|
||||||
|
if module_name.startswith(prefix):
|
||||||
|
return punica_wrapper
|
||||||
|
return None
|
||||||
|
|
||||||
def _register_packed_modules(self, module_full_name: str) -> None:
|
def _register_packed_modules(self, module_full_name: str) -> None:
|
||||||
parts = module_full_name.split(".")
|
parts = module_full_name.split(".")
|
||||||
module_name = parts[-1]
|
module_name = parts[-1]
|
||||||
@ -713,9 +779,10 @@ class LRUCacheLoRAModelManager(LoRAModelManager):
|
|||||||
|
|
||||||
def __init__(self, model: nn.Module, max_num_seqs: int,
|
def __init__(self, model: nn.Module, max_num_seqs: int,
|
||||||
max_num_batched_tokens: int, vocab_size: int,
|
max_num_batched_tokens: int, vocab_size: int,
|
||||||
lora_config: LoRAConfig, device: torch.device):
|
lora_config: LoRAConfig, model_config: ModelConfig,
|
||||||
|
device: torch.device):
|
||||||
super().__init__(model, max_num_seqs, max_num_batched_tokens,
|
super().__init__(model, max_num_seqs, max_num_batched_tokens,
|
||||||
vocab_size, lora_config, device)
|
vocab_size, lora_config, model_config, device)
|
||||||
self._registered_adapters: LoRALRUCache = LoRALRUCache(
|
self._registered_adapters: LoRALRUCache = LoRALRUCache(
|
||||||
self.capacity, self.deactivate_adapter)
|
self.capacity, self.deactivate_adapter)
|
||||||
self._active_adapters: LoRALRUCache = LoRALRUCache(
|
self._active_adapters: LoRALRUCache = LoRALRUCache(
|
||||||
@ -785,6 +852,7 @@ def create_lora_manager(
|
|||||||
max_num_batched_tokens: int,
|
max_num_batched_tokens: int,
|
||||||
vocab_size: int,
|
vocab_size: int,
|
||||||
lora_config: LoRAConfig,
|
lora_config: LoRAConfig,
|
||||||
|
model_config: ModelConfig,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
|
lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
|
||||||
**kwargs) -> LoRAModelManager:
|
**kwargs) -> LoRAModelManager:
|
||||||
@ -797,6 +865,7 @@ def create_lora_manager(
|
|||||||
max_num_batched_tokens=max_num_batched_tokens,
|
max_num_batched_tokens=max_num_batched_tokens,
|
||||||
vocab_size=vocab_size,
|
vocab_size=vocab_size,
|
||||||
lora_config=lora_config,
|
lora_config=lora_config,
|
||||||
|
model_config=model_config,
|
||||||
device=device,
|
device=device,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
return lora_manager
|
return lora_manager
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from vllm.adapter_commons.utils import (add_adapter_worker,
|
|||||||
list_adapters_worker,
|
list_adapters_worker,
|
||||||
set_active_adapters_worker)
|
set_active_adapters_worker)
|
||||||
from vllm.adapter_commons.worker_manager import AbstractWorkerManager
|
from vllm.adapter_commons.worker_manager import AbstractWorkerManager
|
||||||
from vllm.config import LoRAConfig
|
from vllm.config import LoRAConfig, ModelConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.models import (LoRAModel, LoRAModelManager,
|
from vllm.lora.models import (LoRAModel, LoRAModelManager,
|
||||||
LRUCacheLoRAModelManager, create_lora_manager)
|
LRUCacheLoRAModelManager, create_lora_manager)
|
||||||
@ -200,6 +200,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
|
|||||||
def create_lora_manager(
|
def create_lora_manager(
|
||||||
self,
|
self,
|
||||||
model: torch.nn.Module,
|
model: torch.nn.Module,
|
||||||
|
model_config: Optional[ModelConfig] = None,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
lora_manager = create_lora_manager(
|
lora_manager = create_lora_manager(
|
||||||
model,
|
model,
|
||||||
@ -209,6 +210,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
|
|||||||
lora_config=self.lora_config,
|
lora_config=self.lora_config,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
max_num_batched_tokens=self.max_num_batched_tokens,
|
max_num_batched_tokens=self.max_num_batched_tokens,
|
||||||
|
model_config=model_config,
|
||||||
)
|
)
|
||||||
self._adapter_manager = lora_manager
|
self._adapter_manager = lora_manager
|
||||||
return lora_manager.model
|
return lora_manager.model
|
||||||
|
|||||||
@ -279,6 +279,15 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
|
|||||||
height=image_processor.size["longest_edge"],
|
height=image_processor.size["longest_edge"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_num_mm_encoder_tokens(
|
||||||
|
self,
|
||||||
|
num_image_tokens: int,
|
||||||
|
) -> int:
|
||||||
|
hf_config = self.get_hf_config()
|
||||||
|
scale_factor = hf_config.scale_factor
|
||||||
|
|
||||||
|
return num_image_tokens * scale_factor**2
|
||||||
|
|
||||||
|
|
||||||
class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]
|
class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]
|
||||||
):
|
):
|
||||||
|
|||||||
@ -962,6 +962,16 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
|
|||||||
image_processor=None,
|
image_processor=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_num_mm_encoder_tokens(
|
||||||
|
self,
|
||||||
|
num_image_tokens: int,
|
||||||
|
) -> int:
|
||||||
|
hf_config = self.get_hf_config()
|
||||||
|
vision_config = hf_config.vision_config
|
||||||
|
merge_size = vision_config.spatial_merge_size
|
||||||
|
|
||||||
|
return num_image_tokens * merge_size**2
|
||||||
|
|
||||||
|
|
||||||
class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
|
class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import copy
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
@ -44,6 +45,7 @@ class DummyDecoderData(NamedTuple):
|
|||||||
prompt_token_ids: list[int]
|
prompt_token_ids: list[int]
|
||||||
multi_modal_data: MultiModalKwargs
|
multi_modal_data: MultiModalKwargs
|
||||||
multi_modal_placeholders: MultiModalPlaceholderDict
|
multi_modal_placeholders: MultiModalPlaceholderDict
|
||||||
|
multi_modal_token_ids: list[int]
|
||||||
|
|
||||||
|
|
||||||
_I = TypeVar("_I", bound=BaseProcessingInfo)
|
_I = TypeVar("_I", bound=BaseProcessingInfo)
|
||||||
@ -249,6 +251,7 @@ class MultiModalProfiler(Generic[_I]):
|
|||||||
str(self._get_mm_num_tokens(mm_inputs)),
|
str(self._get_mm_num_tokens(mm_inputs)),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
multi_modal_token_ids = copy.deepcopy(prompt_token_ids)
|
||||||
if total_len < seq_len:
|
if total_len < seq_len:
|
||||||
prompt_token_ids.extend([0] * (seq_len - total_len))
|
prompt_token_ids.extend([0] * (seq_len - total_len))
|
||||||
|
|
||||||
@ -256,6 +259,7 @@ class MultiModalProfiler(Generic[_I]):
|
|||||||
prompt_token_ids=prompt_token_ids,
|
prompt_token_ids=prompt_token_ids,
|
||||||
multi_modal_data=mm_inputs["mm_kwargs"],
|
multi_modal_data=mm_inputs["mm_kwargs"],
|
||||||
multi_modal_placeholders=mm_inputs["mm_placeholders"],
|
multi_modal_placeholders=mm_inputs["mm_placeholders"],
|
||||||
|
multi_modal_token_ids=multi_modal_token_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_mm_max_tokens(
|
def get_mm_max_tokens(
|
||||||
|
|||||||
@ -263,6 +263,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
pin_memory=self.pin_memory)
|
pin_memory=self.pin_memory)
|
||||||
self.seq_lens_np = self.seq_lens_cpu.numpy()
|
self.seq_lens_np = self.seq_lens_cpu.numpy()
|
||||||
|
|
||||||
|
# Multimodal LoRA support
|
||||||
|
if self.is_multimodal_model:
|
||||||
|
self.info = self.mm_registry.create_processor(
|
||||||
|
self.model_config, disable_cache=True).info
|
||||||
|
self.supports_mm_lora = hasattr(self.info,
|
||||||
|
"get_num_mm_encoder_tokens")
|
||||||
|
else:
|
||||||
|
self.supports_mm_lora = False
|
||||||
|
|
||||||
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool:
|
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool:
|
||||||
"""
|
"""
|
||||||
Update the order of requests in the batch based on the attention
|
Update the order of requests in the batch based on the attention
|
||||||
@ -892,12 +901,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Batch the multi-modal inputs.
|
# Batch the multi-modal inputs.
|
||||||
|
mm_tokens = list[int]()
|
||||||
mm_inputs = list[MultiModalKwargs]()
|
mm_inputs = list[MultiModalKwargs]()
|
||||||
req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
|
req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
|
||||||
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
|
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
|
||||||
req_state = self.requests[req_id]
|
req_state = self.requests[req_id]
|
||||||
|
|
||||||
for mm_input_id in encoder_input_ids:
|
for mm_input_id in encoder_input_ids:
|
||||||
|
mm_tokens.append(req_state.mm_positions[mm_input_id].length)
|
||||||
mm_inputs.append(req_state.mm_inputs[mm_input_id])
|
mm_inputs.append(req_state.mm_inputs[mm_input_id])
|
||||||
req_ids_pos.append(
|
req_ids_pos.append(
|
||||||
(req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
|
(req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
|
||||||
@ -911,6 +922,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
# encoder outputs.
|
# encoder outputs.
|
||||||
grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)
|
grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)
|
||||||
|
|
||||||
|
if self.lora_config and self.supports_mm_lora:
|
||||||
|
mm_tokens = [
|
||||||
|
self.info.get_num_mm_encoder_tokens(num_token)
|
||||||
|
for num_token in mm_tokens
|
||||||
|
]
|
||||||
|
num_scheduled_tokens = np.array(mm_tokens, dtype=np.int32)
|
||||||
|
self.set_active_loras(self.input_batch,
|
||||||
|
num_scheduled_tokens,
|
||||||
|
is_mm_input=True)
|
||||||
|
|
||||||
encoder_outputs = []
|
encoder_outputs = []
|
||||||
for grouped_mm_inputs in grouped_mm_inputs_list:
|
for grouped_mm_inputs in grouped_mm_inputs_list:
|
||||||
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
|
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
|
||||||
@ -1826,22 +1847,38 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
encoder_budget, max_num_mm_items, dummy_data_modality)
|
encoder_budget, max_num_mm_items, dummy_data_modality)
|
||||||
|
|
||||||
# Create dummy batch of multimodal inputs.
|
# Create dummy batch of multimodal inputs.
|
||||||
dummy_mm_kwargs = self.mm_registry.get_decoder_dummy_data(
|
dummy_mm_data = self.mm_registry.get_decoder_dummy_data(
|
||||||
model_config=self.model_config,
|
model_config=self.model_config,
|
||||||
seq_len=self.max_num_tokens,
|
seq_len=self.max_num_tokens,
|
||||||
mm_counts={
|
mm_counts={dummy_data_modality: 1},
|
||||||
dummy_data_modality: 1
|
)
|
||||||
},
|
dummy_mm_kwargs = dummy_mm_data.multi_modal_data
|
||||||
).multi_modal_data
|
dummy_mm_token_ids = dummy_mm_data.multi_modal_token_ids
|
||||||
|
|
||||||
|
max_num_mm_items = 1 # temporary
|
||||||
batched_dummy_mm_inputs = MultiModalKwargs.batch(
|
batched_dummy_mm_inputs = MultiModalKwargs.batch(
|
||||||
[dummy_mm_kwargs] * max_num_mm_items)
|
[dummy_mm_kwargs] * max_num_mm_items) # ???
|
||||||
batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
|
batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
|
||||||
batched_dummy_mm_inputs, device=self.device)
|
batched_dummy_mm_inputs, device=self.device)
|
||||||
|
|
||||||
# Run multimodal encoder.
|
if self.supports_mm_lora:
|
||||||
dummy_encoder_outputs = self.model.get_multimodal_embeddings(
|
num_scheduled_tokens_list = [
|
||||||
**batched_dummy_mm_inputs)
|
self.info.get_num_mm_encoder_tokens(
|
||||||
|
len(dummy_mm_token_ids))
|
||||||
|
] * max_num_mm_items
|
||||||
|
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
|
||||||
|
dtype=np.int32)
|
||||||
|
lora_config = self.lora_config
|
||||||
|
else:
|
||||||
|
num_scheduled_tokens = None
|
||||||
|
lora_config = None
|
||||||
|
|
||||||
|
with self.maybe_dummy_run_with_lora(lora_config,
|
||||||
|
num_scheduled_tokens,
|
||||||
|
is_mm_input=True):
|
||||||
|
# Run multimodal encoder.
|
||||||
|
dummy_encoder_outputs = self.model.get_multimodal_embeddings(
|
||||||
|
**batched_dummy_mm_inputs)
|
||||||
|
|
||||||
sanity_check_mm_encoder_outputs(
|
sanity_check_mm_encoder_outputs(
|
||||||
dummy_encoder_outputs,
|
dummy_encoder_outputs,
|
||||||
|
|||||||
@ -50,11 +50,13 @@ class LoRAModelRunnerMixin:
|
|||||||
model.embedding_padding_modules,
|
model.embedding_padding_modules,
|
||||||
max_position_embeddings=text_config.max_position_embeddings,
|
max_position_embeddings=text_config.max_position_embeddings,
|
||||||
)
|
)
|
||||||
return self.lora_manager.create_lora_manager(model)
|
return self.lora_manager.create_lora_manager(model, model_config)
|
||||||
|
|
||||||
def _set_active_loras(self, prompt_lora_mapping: tuple[int, ...],
|
def _set_active_loras(self,
|
||||||
|
prompt_lora_mapping: tuple[int, ...],
|
||||||
token_lora_mapping: tuple[int, ...],
|
token_lora_mapping: tuple[int, ...],
|
||||||
lora_requests: set[LoRARequest]) -> None:
|
lora_requests: set[LoRARequest],
|
||||||
|
is_mm_input: bool = False) -> None:
|
||||||
if not self.lora_manager:
|
if not self.lora_manager:
|
||||||
raise RuntimeError("LoRA is not enabled.")
|
raise RuntimeError("LoRA is not enabled.")
|
||||||
|
|
||||||
@ -64,11 +66,14 @@ class LoRAModelRunnerMixin:
|
|||||||
# decode and this flag is generally ignored.
|
# decode and this flag is generally ignored.
|
||||||
lora_mapping = LoRAMapping(token_lora_mapping,
|
lora_mapping = LoRAMapping(token_lora_mapping,
|
||||||
prompt_lora_mapping,
|
prompt_lora_mapping,
|
||||||
is_prefill=True)
|
is_prefill=True,
|
||||||
|
is_mm_input=is_mm_input)
|
||||||
self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
|
self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
|
||||||
|
|
||||||
def set_active_loras(self, input_batch: InputBatch,
|
def set_active_loras(self,
|
||||||
num_scheduled_tokens: np.ndarray) -> None:
|
input_batch: InputBatch,
|
||||||
|
num_scheduled_tokens: np.ndarray,
|
||||||
|
is_mm_input: bool = False) -> None:
|
||||||
|
|
||||||
prompt_lora_mapping: tuple[int, ...] # of size input_batch.num_reqs
|
prompt_lora_mapping: tuple[int, ...] # of size input_batch.num_reqs
|
||||||
token_lora_mapping: tuple[int,
|
token_lora_mapping: tuple[int,
|
||||||
@ -77,11 +82,13 @@ class LoRAModelRunnerMixin:
|
|||||||
prompt_lora_mapping, token_lora_mapping, lora_requests = \
|
prompt_lora_mapping, token_lora_mapping, lora_requests = \
|
||||||
input_batch.make_lora_inputs(num_scheduled_tokens)
|
input_batch.make_lora_inputs(num_scheduled_tokens)
|
||||||
return self._set_active_loras(prompt_lora_mapping, token_lora_mapping,
|
return self._set_active_loras(prompt_lora_mapping, token_lora_mapping,
|
||||||
lora_requests)
|
lora_requests, is_mm_input)
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig,
|
def maybe_dummy_run_with_lora(self,
|
||||||
num_scheduled_tokens: np.ndarray):
|
lora_config: LoRAConfig,
|
||||||
|
num_scheduled_tokens: np.ndarray,
|
||||||
|
is_mm_input: bool = False):
|
||||||
if lora_config is None:
|
if lora_config is None:
|
||||||
yield
|
yield
|
||||||
else:
|
else:
|
||||||
@ -117,7 +124,7 @@ class LoRAModelRunnerMixin:
|
|||||||
|
|
||||||
self._set_active_loras(tuple(prompt_lora_mapping),
|
self._set_active_loras(tuple(prompt_lora_mapping),
|
||||||
tuple(token_lora_mapping),
|
tuple(token_lora_mapping),
|
||||||
lora_requests)
|
lora_requests, is_mm_input)
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user