diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py
index 7a9d5237ab75..d9de0f3cfeb3 100644
--- a/vllm/lora/layers.py
+++ b/vllm/lora/layers.py
@@ -866,6 +866,11 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
                 and len(packed_modules_list) == 3)
 
 
+#TODO: Implement this
+class QKVCrossParallelLinearWithLoRA(BaseLayerWithLoRA):
+    pass
+
+
 class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
 
     def __init__(self, base_layer: RowParallelLinear) -> None:
diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index 251d95e41dc3..566149c9cf24 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -52,6 +52,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
+from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
@@ -1181,6 +1182,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
         super().__init__()
         config: MllamaConfig = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
+        self.config = config
         self.quant_config = quant_config
         self.vocab_size = config.text_config.vocab_size
         self.hidden_size = config.text_config.hidden_size
@@ -1517,6 +1519,15 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
             updated_params.add(name)
         return updated_params
 
+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefix in multimodal models
+        """
+        return MultiModelKeys.from_string_field(
+            language_model="language_model",
+            connector="multi_modal_projector",
+            tower_model="vision_model")
+
 
 def skip_attention_mask(sparse_mask: List[List[int]]) -> bool:
     for mask in sparse_mask:
diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py
index 5f39f2fa4947..72ff9d66a689 100644
--- a/vllm/worker/enc_dec_model_runner.py
+++ b/vllm/worker/enc_dec_model_runner.py
@@ -16,6 +16,7 @@ from vllm.config import VllmConfig
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY, InputRegistry
 from vllm.logger import init_logger
+from vllm.lora.request import LoRARequest
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
@@ -34,6 +35,7 @@ from vllm.worker.model_runner_base import (
 from vllm.worker.utils import assert_enc_dec_mr_supported_scenario
 
 logger = init_logger(__name__)
+LORA_WARMUP_RANK = 8
 
 
 @dataclasses.dataclass(frozen=True)
@@ -160,7 +162,11 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
         if num_steps > 1:
             raise ValueError("num_steps > 1 is not supported in "
                              "EncoderDecoderModelRunner")
-
+        if self.lora_config:
+            assert model_input.lora_requests is not None
+            assert model_input.lora_mapping is not None
+            self.set_active_loras(model_input.lora_requests,
+                                  model_input.lora_mapping)
         if (model_input.attn_metadata is not None
                 and model_input.attn_metadata.prefill_metadata is None
                 and model_input.attn_metadata.decode_metadata.use_cuda_graph):
@@ -268,6 +274,22 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
         max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
         max_num_seqs = self.scheduler_config.max_num_seqs
 
+        # This represents the maximum number of different requests
+        # that will have unique loras, and therefore the max amount of
+        # memory consumption. Create dummy lora request copies from the
+        # lora request passed in, which contains a lora from the lora
+        # warmup path.
+        dummy_lora_requests: List[LoRARequest] = []
+        dummy_lora_requests_per_seq: List[LoRARequest] = []
+        if self.lora_config:
+            dummy_lora_requests = self._add_dummy_loras(
+                self.lora_config.max_loras)
+            assert len(dummy_lora_requests) == self.lora_config.max_loras
+            dummy_lora_requests_per_seq = [
+                dummy_lora_requests[idx % len(dummy_lora_requests)]
+                for idx in range(max_num_seqs)
+            ]
+
         # Profile memory usage with max_num_sequences sequences and the total
         # number of tokens equal to max_num_batched_tokens.
         seqs: List[SequenceGroupMetadata] = []
@@ -315,6 +337,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
                 block_tables=None,
                 encoder_seq_data=encoder_dummy_data.seq_data,
                 cross_block_table=None,
+                lora_request=dummy_lora_requests_per_seq[group_id]
+                if dummy_lora_requests_per_seq else None,
                 multi_modal_data=decoder_dummy_data.multi_modal_data
                 or encoder_dummy_data.multi_modal_data,
                 multi_modal_placeholders=decoder_dummy_data.