mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-25 03:13:32 +08:00
add mm_punica_warpper
Signed-off-by: bk-201 <joy25810@foxmail.com>
This commit is contained in:
parent
23baa2180b
commit
e7026a7c50
@ -27,6 +27,10 @@ argcomplete==3.5.1
|
||||
# via datamodel-code-generator
|
||||
arrow==1.3.0
|
||||
# via isoduration
|
||||
async-timeout==5.0.1
|
||||
# via
|
||||
# aiohttp
|
||||
# redis
|
||||
attrs==24.2.0
|
||||
# via
|
||||
# aiohttp
|
||||
@ -129,6 +133,11 @@ eval-type-backport==0.2.2
|
||||
# via mteb
|
||||
evaluate==0.4.3
|
||||
# via lm-eval
|
||||
exceptiongroup==1.3.0
|
||||
# via
|
||||
# anyio
|
||||
# hypothesis
|
||||
# pytest
|
||||
fastparquet==2024.11.0
|
||||
# via genai-perf
|
||||
fastrlock==0.8.2
|
||||
@ -640,7 +649,6 @@ setuptools==77.0.3
|
||||
# via
|
||||
# mamba-ssm
|
||||
# pytablewriter
|
||||
# torch
|
||||
# triton
|
||||
shellingham==1.5.4
|
||||
# via typer
|
||||
@ -700,8 +708,13 @@ tokenizers==0.21.1
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# transformers
|
||||
toml==0.10.2
|
||||
# via datamodel-code-generator
|
||||
tomli==2.2.1
|
||||
# via schemathesis
|
||||
# via
|
||||
# black
|
||||
# pytest
|
||||
# schemathesis
|
||||
tomli-w==1.2.0
|
||||
# via schemathesis
|
||||
torch==2.7.0+cu128
|
||||
@ -775,13 +788,18 @@ types-python-dateutil==2.9.0.20241206
|
||||
# via arrow
|
||||
typing-extensions==4.12.2
|
||||
# via
|
||||
# anyio
|
||||
# black
|
||||
# exceptiongroup
|
||||
# huggingface-hub
|
||||
# librosa
|
||||
# mistral-common
|
||||
# mteb
|
||||
# multidict
|
||||
# pqdm
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# rich
|
||||
# torch
|
||||
# typer
|
||||
tzdata==2024.2
|
||||
|
||||
@ -77,6 +77,7 @@ def _not_fully_sharded_can_replace(can_replace):
|
||||
@dataclass
|
||||
class LoRAMapping(AdapterMapping):
|
||||
is_prefill: bool = False
|
||||
is_mm_input: bool = False
|
||||
|
||||
|
||||
class BaseLayerWithLoRA(nn.Module):
|
||||
@ -410,6 +411,9 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
|
||||
|
||||
# Store original shape for later reshaping
|
||||
original_shape = output.shape if output.ndim == 3 else None
|
||||
|
||||
# In transformers backend, x and output have extra batch dimension like
|
||||
# (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
|
||||
# therefore we need to flatten the batch dimensions.
|
||||
@ -424,6 +428,10 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
|
||||
if not current_platform.can_update_inplace():
|
||||
output = lora_output
|
||||
|
||||
# Restore original shape if it was flattened
|
||||
if original_shape is not None:
|
||||
output = output.reshape(original_shape)
|
||||
|
||||
return output
|
||||
|
||||
@property
|
||||
|
||||
@ -17,14 +17,14 @@ from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel,
|
||||
from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter,
|
||||
get_adapter, list_adapters,
|
||||
remove_adapter, set_adapter_mapping)
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.config import LoRAConfig, ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.layers import (BaseLayerWithLoRA,
|
||||
LinearScalingRotaryEmbeddingWithLoRA,
|
||||
LoRAMapping)
|
||||
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
|
||||
from vllm.lora.peft_helper import PEFTHelper
|
||||
from vllm.lora.punica_wrapper import get_punica_wrapper
|
||||
from vllm.lora.punica_wrapper import PunicaWrapperBase, get_punica_wrapper
|
||||
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
|
||||
get_supported_lora_modules,
|
||||
is_regex_target_modules,
|
||||
@ -33,6 +33,7 @@ from vllm.model_executor.models import SupportsLoRA, supports_multimodal
|
||||
from vllm.model_executor.models.interfaces import is_pooling_model
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.utils import is_pin_memory_available
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -311,6 +312,7 @@ class LoRAModelManager(AdapterModelManager):
|
||||
max_num_batched_tokens: int,
|
||||
vocab_size: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: Optional[ModelConfig],
|
||||
device: torch.device,
|
||||
):
|
||||
"""Create a LoRAModelManager and adapter for a given model.
|
||||
@ -357,6 +359,30 @@ class LoRAModelManager(AdapterModelManager):
|
||||
# In case the model only supports LoRA for
|
||||
# text modules (e.g. ChatGLM)
|
||||
and hasattr(self.model, "get_mm_mapping"))
|
||||
# For v0 compatibility
|
||||
if model_config is not None:
|
||||
self.mm_registry = MULTIMODAL_REGISTRY
|
||||
self.info = self.mm_registry.create_processor(
|
||||
model_config, disable_cache=True).info
|
||||
self.supports_mm_lora = self.supports_mm and hasattr(
|
||||
self.info, "get_num_mm_encoder_tokens")
|
||||
else:
|
||||
self.supports_mm_lora = False
|
||||
if self.supports_mm_lora:
|
||||
self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
|
||||
self.mm_punica_wrapper_mapping = {
|
||||
name:
|
||||
get_punica_wrapper(
|
||||
self.info.get_num_mm_encoder_tokens(
|
||||
max_num_batched_tokens),
|
||||
max_batches=self.max_num_seqs, # TODO
|
||||
device=self.device,
|
||||
max_loras=self.lora_config.max_loras,
|
||||
)
|
||||
for name in self.mm_mapping.tower_model
|
||||
}
|
||||
self.mm_punica_wrapper_mapping[
|
||||
self.mm_mapping.language_model[0]] = self.punica_wrapper
|
||||
self.is_pooling_model = is_pooling_model(self.model)
|
||||
self.packed_modules: dict[str, list[str]] = {}
|
||||
self.modules: dict[str, BaseLayerWithLoRA] = {}
|
||||
@ -452,14 +478,35 @@ class LoRAModelManager(AdapterModelManager):
|
||||
|
||||
def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
|
||||
# update lora states
|
||||
self.punica_wrapper.update_metadata(
|
||||
mapping,
|
||||
self.lora_index_to_id,
|
||||
self.lora_slots + 1,
|
||||
self.vocab_size,
|
||||
self.lora_config.lora_extra_vocab_size,
|
||||
self.long_lora_context,
|
||||
)
|
||||
if not self.supports_mm_lora:
|
||||
self.punica_wrapper.update_metadata(
|
||||
mapping,
|
||||
self.lora_index_to_id,
|
||||
self.lora_slots + 1,
|
||||
self.vocab_size,
|
||||
self.lora_config.lora_extra_vocab_size,
|
||||
self.long_lora_context,
|
||||
)
|
||||
elif mapping.is_mm_input:
|
||||
self.mm_punica_wrapper_mapping[
|
||||
self.mm_mapping.tower_model[0]].update_metadata(
|
||||
mapping,
|
||||
self.lora_index_to_id,
|
||||
self.lora_slots + 1,
|
||||
self.vocab_size,
|
||||
self.lora_config.lora_extra_vocab_size,
|
||||
self.long_lora_context,
|
||||
)
|
||||
else:
|
||||
self.mm_punica_wrapper_mapping[
|
||||
self.mm_mapping.language_model[0]].update_metadata(
|
||||
mapping,
|
||||
self.lora_index_to_id,
|
||||
self.lora_slots + 1,
|
||||
self.vocab_size,
|
||||
self.lora_config.lora_extra_vocab_size,
|
||||
self.long_lora_context,
|
||||
)
|
||||
|
||||
def remove_all_adapters(self):
|
||||
"""Remove all LoRAModels from the manager."""
|
||||
@ -476,7 +523,9 @@ class LoRAModelManager(AdapterModelManager):
|
||||
continue
|
||||
# A temporary approach for multimodal models to support LoRA
|
||||
# TODO: Remove this restriction
|
||||
if self._filter_unsupported_mm_module(module_name):
|
||||
if (self._filter_unsupported_mm_module(module_name)
|
||||
and not self.supports_mm_lora
|
||||
or self._get_mm_punica_wrapper(module_name) is None):
|
||||
logger.warning(
|
||||
"Regarding multimodal models, vLLM currently only supports "
|
||||
"adding LoRA to language model, %s will be ignored.",
|
||||
@ -519,7 +568,11 @@ class LoRAModelManager(AdapterModelManager):
|
||||
self.register_module(module_name, new_module)
|
||||
self._register_packed_modules(module_name)
|
||||
# All lora layers share the same punica_wrapper based on reference.
|
||||
new_module.set_mapping(self.punica_wrapper)
|
||||
if self.supports_mm_lora:
|
||||
new_module.set_mapping(
|
||||
self._get_mm_punica_wrapper(module_name))
|
||||
else:
|
||||
new_module.set_mapping(self.punica_wrapper)
|
||||
|
||||
def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
|
||||
assert isinstance(module, BaseLayerWithLoRA)
|
||||
@ -615,6 +668,19 @@ class LoRAModelManager(AdapterModelManager):
|
||||
[module_name.startswith(prefix) for prefix in prefix_lst])
|
||||
return False
|
||||
|
||||
def _get_mm_punica_wrapper(self, module_name: str) -> PunicaWrapperBase:
|
||||
"""
|
||||
TODO
|
||||
"""
|
||||
if self.supports_mm_lora:
|
||||
for (
|
||||
prefix,
|
||||
punica_wrapper,
|
||||
) in self.mm_punica_wrapper_mapping.items():
|
||||
if module_name.startswith(prefix):
|
||||
return punica_wrapper
|
||||
return None
|
||||
|
||||
def _register_packed_modules(self, module_full_name: str) -> None:
|
||||
parts = module_full_name.split(".")
|
||||
module_name = parts[-1]
|
||||
@ -713,9 +779,10 @@ class LRUCacheLoRAModelManager(LoRAModelManager):
|
||||
|
||||
def __init__(self, model: nn.Module, max_num_seqs: int,
|
||||
max_num_batched_tokens: int, vocab_size: int,
|
||||
lora_config: LoRAConfig, device: torch.device):
|
||||
lora_config: LoRAConfig, model_config: ModelConfig,
|
||||
device: torch.device):
|
||||
super().__init__(model, max_num_seqs, max_num_batched_tokens,
|
||||
vocab_size, lora_config, device)
|
||||
vocab_size, lora_config, model_config, device)
|
||||
self._registered_adapters: LoRALRUCache = LoRALRUCache(
|
||||
self.capacity, self.deactivate_adapter)
|
||||
self._active_adapters: LoRALRUCache = LoRALRUCache(
|
||||
@ -785,6 +852,7 @@ def create_lora_manager(
|
||||
max_num_batched_tokens: int,
|
||||
vocab_size: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: ModelConfig,
|
||||
device: torch.device,
|
||||
lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
|
||||
**kwargs) -> LoRAModelManager:
|
||||
@ -797,6 +865,7 @@ def create_lora_manager(
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
vocab_size=vocab_size,
|
||||
lora_config=lora_config,
|
||||
model_config=model_config,
|
||||
device=device,
|
||||
**kwargs)
|
||||
return lora_manager
|
||||
|
||||
@ -10,7 +10,7 @@ from vllm.adapter_commons.utils import (add_adapter_worker,
|
||||
list_adapters_worker,
|
||||
set_active_adapters_worker)
|
||||
from vllm.adapter_commons.worker_manager import AbstractWorkerManager
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.config import LoRAConfig, ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.models import (LoRAModel, LoRAModelManager,
|
||||
LRUCacheLoRAModelManager, create_lora_manager)
|
||||
@ -200,6 +200,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
|
||||
def create_lora_manager(
|
||||
self,
|
||||
model: torch.nn.Module,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
) -> Any:
|
||||
lora_manager = create_lora_manager(
|
||||
model,
|
||||
@ -209,6 +210,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
|
||||
lora_config=self.lora_config,
|
||||
device=self.device,
|
||||
max_num_batched_tokens=self.max_num_batched_tokens,
|
||||
model_config=model_config,
|
||||
)
|
||||
self._adapter_manager = lora_manager
|
||||
return lora_manager.model
|
||||
|
||||
@ -279,6 +279,15 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
|
||||
height=image_processor.size["longest_edge"],
|
||||
)
|
||||
|
||||
def get_num_mm_encoder_tokens(
|
||||
self,
|
||||
num_image_tokens: int,
|
||||
) -> int:
|
||||
hf_config = self.get_hf_config()
|
||||
scale_factor = hf_config.scale_factor
|
||||
|
||||
return num_image_tokens * scale_factor**2
|
||||
|
||||
|
||||
class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]
|
||||
):
|
||||
|
||||
@ -962,6 +962,16 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
|
||||
image_processor=None,
|
||||
)
|
||||
|
||||
def get_num_mm_encoder_tokens(
|
||||
self,
|
||||
num_image_tokens: int,
|
||||
) -> int:
|
||||
hf_config = self.get_hf_config()
|
||||
vision_config = hf_config.vision_config
|
||||
merge_size = vision_config.spatial_merge_size
|
||||
|
||||
return num_image_tokens * merge_size**2
|
||||
|
||||
|
||||
class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
|
||||
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import copy
|
||||
from abc import ABC
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass, field
|
||||
@ -44,6 +45,7 @@ class DummyDecoderData(NamedTuple):
|
||||
prompt_token_ids: list[int]
|
||||
multi_modal_data: MultiModalKwargs
|
||||
multi_modal_placeholders: MultiModalPlaceholderDict
|
||||
multi_modal_token_ids: list[int]
|
||||
|
||||
|
||||
_I = TypeVar("_I", bound=BaseProcessingInfo)
|
||||
@ -249,6 +251,7 @@ class MultiModalProfiler(Generic[_I]):
|
||||
str(self._get_mm_num_tokens(mm_inputs)),
|
||||
)
|
||||
|
||||
multi_modal_token_ids = copy.deepcopy(prompt_token_ids)
|
||||
if total_len < seq_len:
|
||||
prompt_token_ids.extend([0] * (seq_len - total_len))
|
||||
|
||||
@ -256,6 +259,7 @@ class MultiModalProfiler(Generic[_I]):
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
multi_modal_data=mm_inputs["mm_kwargs"],
|
||||
multi_modal_placeholders=mm_inputs["mm_placeholders"],
|
||||
multi_modal_token_ids=multi_modal_token_ids,
|
||||
)
|
||||
|
||||
def get_mm_max_tokens(
|
||||
|
||||
@ -263,6 +263,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
pin_memory=self.pin_memory)
|
||||
self.seq_lens_np = self.seq_lens_cpu.numpy()
|
||||
|
||||
# Multimodal LoRA support
|
||||
if self.is_multimodal_model:
|
||||
self.info = self.mm_registry.create_processor(
|
||||
self.model_config, disable_cache=True).info
|
||||
self.supports_mm_lora = hasattr(self.info,
|
||||
"get_num_mm_encoder_tokens")
|
||||
else:
|
||||
self.supports_mm_lora = False
|
||||
|
||||
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool:
|
||||
"""
|
||||
Update the order of requests in the batch based on the attention
|
||||
@ -892,12 +901,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
return
|
||||
|
||||
# Batch the multi-modal inputs.
|
||||
mm_tokens = list[int]()
|
||||
mm_inputs = list[MultiModalKwargs]()
|
||||
req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
|
||||
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
|
||||
req_state = self.requests[req_id]
|
||||
|
||||
for mm_input_id in encoder_input_ids:
|
||||
mm_tokens.append(req_state.mm_positions[mm_input_id].length)
|
||||
mm_inputs.append(req_state.mm_inputs[mm_input_id])
|
||||
req_ids_pos.append(
|
||||
(req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
|
||||
@ -911,6 +922,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# encoder outputs.
|
||||
grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)
|
||||
|
||||
if self.lora_config and self.supports_mm_lora:
|
||||
mm_tokens = [
|
||||
self.info.get_num_mm_encoder_tokens(num_token)
|
||||
for num_token in mm_tokens
|
||||
]
|
||||
num_scheduled_tokens = np.array(mm_tokens, dtype=np.int32)
|
||||
self.set_active_loras(self.input_batch,
|
||||
num_scheduled_tokens,
|
||||
is_mm_input=True)
|
||||
|
||||
encoder_outputs = []
|
||||
for grouped_mm_inputs in grouped_mm_inputs_list:
|
||||
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
|
||||
@ -1826,22 +1847,38 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
encoder_budget, max_num_mm_items, dummy_data_modality)
|
||||
|
||||
# Create dummy batch of multimodal inputs.
|
||||
dummy_mm_kwargs = self.mm_registry.get_decoder_dummy_data(
|
||||
dummy_mm_data = self.mm_registry.get_decoder_dummy_data(
|
||||
model_config=self.model_config,
|
||||
seq_len=self.max_num_tokens,
|
||||
mm_counts={
|
||||
dummy_data_modality: 1
|
||||
},
|
||||
).multi_modal_data
|
||||
mm_counts={dummy_data_modality: 1},
|
||||
)
|
||||
dummy_mm_kwargs = dummy_mm_data.multi_modal_data
|
||||
dummy_mm_token_ids = dummy_mm_data.multi_modal_token_ids
|
||||
|
||||
max_num_mm_items = 1 # temporary
|
||||
batched_dummy_mm_inputs = MultiModalKwargs.batch(
|
||||
[dummy_mm_kwargs] * max_num_mm_items)
|
||||
[dummy_mm_kwargs] * max_num_mm_items) # ???
|
||||
batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
|
||||
batched_dummy_mm_inputs, device=self.device)
|
||||
|
||||
# Run multimodal encoder.
|
||||
dummy_encoder_outputs = self.model.get_multimodal_embeddings(
|
||||
**batched_dummy_mm_inputs)
|
||||
if self.supports_mm_lora:
|
||||
num_scheduled_tokens_list = [
|
||||
self.info.get_num_mm_encoder_tokens(
|
||||
len(dummy_mm_token_ids))
|
||||
] * max_num_mm_items
|
||||
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
|
||||
dtype=np.int32)
|
||||
lora_config = self.lora_config
|
||||
else:
|
||||
num_scheduled_tokens = None
|
||||
lora_config = None
|
||||
|
||||
with self.maybe_dummy_run_with_lora(lora_config,
|
||||
num_scheduled_tokens,
|
||||
is_mm_input=True):
|
||||
# Run multimodal encoder.
|
||||
dummy_encoder_outputs = self.model.get_multimodal_embeddings(
|
||||
**batched_dummy_mm_inputs)
|
||||
|
||||
sanity_check_mm_encoder_outputs(
|
||||
dummy_encoder_outputs,
|
||||
|
||||
@ -50,11 +50,13 @@ class LoRAModelRunnerMixin:
|
||||
model.embedding_padding_modules,
|
||||
max_position_embeddings=text_config.max_position_embeddings,
|
||||
)
|
||||
return self.lora_manager.create_lora_manager(model)
|
||||
return self.lora_manager.create_lora_manager(model, model_config)
|
||||
|
||||
def _set_active_loras(self, prompt_lora_mapping: tuple[int, ...],
|
||||
def _set_active_loras(self,
|
||||
prompt_lora_mapping: tuple[int, ...],
|
||||
token_lora_mapping: tuple[int, ...],
|
||||
lora_requests: set[LoRARequest]) -> None:
|
||||
lora_requests: set[LoRARequest],
|
||||
is_mm_input: bool = False) -> None:
|
||||
if not self.lora_manager:
|
||||
raise RuntimeError("LoRA is not enabled.")
|
||||
|
||||
@ -64,11 +66,14 @@ class LoRAModelRunnerMixin:
|
||||
# decode and this flag is generally ignored.
|
||||
lora_mapping = LoRAMapping(token_lora_mapping,
|
||||
prompt_lora_mapping,
|
||||
is_prefill=True)
|
||||
is_prefill=True,
|
||||
is_mm_input=is_mm_input)
|
||||
self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
|
||||
|
||||
def set_active_loras(self, input_batch: InputBatch,
|
||||
num_scheduled_tokens: np.ndarray) -> None:
|
||||
def set_active_loras(self,
|
||||
input_batch: InputBatch,
|
||||
num_scheduled_tokens: np.ndarray,
|
||||
is_mm_input: bool = False) -> None:
|
||||
|
||||
prompt_lora_mapping: tuple[int, ...] # of size input_batch.num_reqs
|
||||
token_lora_mapping: tuple[int,
|
||||
@ -77,11 +82,13 @@ class LoRAModelRunnerMixin:
|
||||
prompt_lora_mapping, token_lora_mapping, lora_requests = \
|
||||
input_batch.make_lora_inputs(num_scheduled_tokens)
|
||||
return self._set_active_loras(prompt_lora_mapping, token_lora_mapping,
|
||||
lora_requests)
|
||||
lora_requests, is_mm_input)
|
||||
|
||||
@contextmanager
|
||||
def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig,
|
||||
num_scheduled_tokens: np.ndarray):
|
||||
def maybe_dummy_run_with_lora(self,
|
||||
lora_config: LoRAConfig,
|
||||
num_scheduled_tokens: np.ndarray,
|
||||
is_mm_input: bool = False):
|
||||
if lora_config is None:
|
||||
yield
|
||||
else:
|
||||
@ -117,7 +124,7 @@ class LoRAModelRunnerMixin:
|
||||
|
||||
self._set_active_loras(tuple(prompt_lora_mapping),
|
||||
tuple(token_lora_mapping),
|
||||
lora_requests)
|
||||
lora_requests, is_mm_input)
|
||||
|
||||
yield
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user