diff --git a/vllm/lora/ops/triton_ops/lora_expand.py b/vllm/lora/ops/triton_ops/lora_expand.py
index b47e491ad7ed1..eacc6fb46ebd7 100644
--- a/vllm/lora/ops/triton_ops/lora_expand.py
+++ b/vllm/lora/ops/triton_ops/lora_expand.py
@@ -136,6 +136,7 @@ def _lora_expand(
     num_tokens_per_lora: torch.Tensor,  # shape [max-loras + 1]
     lora_token_start_loc: torch.Tensor,  # shape [max-loras + 2]
     lora_ids: torch.Tensor,  # shape [max-loras + 1]
+    no_lora_flag_cpu: torch.Tensor,  # shape [1]
     offset_start: int = 0,
     add_inputs: bool = False,
 ) -> None:
@@ -157,11 +158,19 @@ def _lora_expand(
             identifies the the region in token_indices_sorted_by_lora_ids that
             LoRA lora_ids[i] should process.
         lora_ids (torch.Tensor): LoRA ids to process.
+        no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1 that indicates
+            whether any of the requests require LoRA.
         offset_start (int, optional): Offset start for output_tensor.
             Defaults to 0.
         add_inputs (bool, optional): Whether to add the input tensor to the
             output tensor. Defaults to False.
     """
+
+    assert no_lora_flag_cpu.numel() == 1
+    if no_lora_flag_cpu.item():
+        # None of the inputs require LoRA.
+        return
+
     assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
     for weight in lora_b_weights:
         assert weight.dtype in [torch.float16, torch.bfloat16]
@@ -170,6 +179,8 @@ def _lora_expand(
     assert output_tensor.is_contiguous()
 
     # metadata sanity check.
+    M = inputs.size(1)
+    assert token_lora_mapping.size(0) == M
     assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(
         0)
     assert lora_ids.size(0) == num_tokens_per_lora.size(0)
@@ -181,7 +192,6 @@ def _lora_expand(
                                                       inputs.device)
 
     K = lora_b_weights[0].shape[-1]  # K= rank
-    M = inputs.size(1)
     ADD_INPUTS = add_inputs
     MAX_LORAS = lora_ids.size(0)
     CAST_TYPE = False
@@ -263,6 +273,7 @@ def _lora_expand_fake(
     num_tokens_per_lora: torch.Tensor,
     lora_token_start_loc: torch.Tensor,
     lora_ids: torch.Tensor,
+    no_lora_flag_cpu: torch.Tensor,
     offset_start: int = 0,
     add_inputs: bool = False,
 ) -> None:
diff --git a/vllm/lora/ops/triton_ops/lora_kernel_metadata.py b/vllm/lora/ops/triton_ops/lora_kernel_metadata.py
index 2add1177e84c8..1dcdfc814a891 100644
--- a/vllm/lora/ops/triton_ops/lora_kernel_metadata.py
+++ b/vllm/lora/ops/triton_ops/lora_kernel_metadata.py
@@ -17,6 +17,17 @@ class LoRAKernelMeta:
     num_tokens_per_lora: torch.Tensor
     lora_token_start_loc: torch.Tensor
 
+    # The V1 architecture uses traced torch.compile graphs to execute a
+    # forward pass. Things to note about this process:
+    # 1. Tracing specializes all Python scalar values into constant
+    #    values.
+    # 2. Tracing cannot handle dynamic control flow. (Dynamic control flow
+    #    is an experimental feature in PyTorch.)
+    # 3. The internals of torch.ops functions are not traced.
+    # We disguise the "no_lora" flag as a CPU tensor and leverage point 3
+    # to early-exit from inside the lora_expand / lora_shrink torch ops.
+    no_lora_flag_cpu: torch.Tensor
+
     @staticmethod
     def make(max_loras: int, max_num_tokens: int,
              device: Union[torch.device, str]) -> "LoRAKernelMeta":
@@ -47,17 +58,24 @@ class LoRAKernelMeta:
         lora_token_start_loc = torch.zeros(max_loras + 2,
                                            dtype=torch.int32,
                                            device=device)
+
+        no_lora_flag_cpu = torch.tensor([False],
+                                        dtype=torch.bool,
+                                        device='cpu')
+
         return LoRAKernelMeta(
             token_lora_mapping=token_lora_mapping,
             token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids,
             active_lora_ids=active_lora_ids,
             num_tokens_per_lora=num_tokens_per_lora,
-            lora_token_start_loc=lora_token_start_loc)
+            lora_token_start_loc=lora_token_start_loc,
+            no_lora_flag_cpu=no_lora_flag_cpu)
 
     def _reset(self):
         self.active_lora_ids.fill_(-1)
         self.num_tokens_per_lora.fill_(0)
         self.lora_token_start_loc.fill_(0)
+        self.no_lora_flag_cpu.fill_(False)
 
     def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
         """
@@ -70,6 +88,14 @@ class LoRAKernelMeta:
 
         self._reset()
 
+        # Check and record the no-LoRA case.
+        no_lora = torch.all(token_lora_mapping == -1)
+        self.no_lora_flag_cpu[0] = no_lora
+
+        if no_lora:
+            # Early exit. LoRA kernels will not be run.
+            return
+
         num_tokens = token_lora_mapping.size(0)
 
         # copy token lora mapping
@@ -100,7 +126,7 @@ class LoRAKernelMeta:
     def meta_args(
         self, token_nums: int
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
-               torch.Tensor]:
+               torch.Tensor, torch.Tensor]:
         """
         This function returns the kernel metadata required for the current
         forward pass execution of the kernel. The function returns all the
@@ -111,7 +137,11 @@ class LoRAKernelMeta:
             token_nums (int): Number of input tokens in the current forward
                 pass.
         """
-        return (self.token_lora_mapping[:token_nums],
-                self.token_indices_sorted_by_lora_ids[:token_nums],
-                self.num_tokens_per_lora, self.lora_token_start_loc,
-                self.active_lora_ids)
+        return (
+            self.token_lora_mapping[:token_nums],
+            self.token_indices_sorted_by_lora_ids[:token_nums],
+            self.num_tokens_per_lora,
+            self.lora_token_start_loc,
+            self.active_lora_ids,
+            self.no_lora_flag_cpu,
+        )
diff --git a/vllm/lora/ops/triton_ops/lora_shrink.py b/vllm/lora/ops/triton_ops/lora_shrink.py
index a97c50c44f47a..82331939d859b 100644
--- a/vllm/lora/ops/triton_ops/lora_shrink.py
+++ b/vllm/lora/ops/triton_ops/lora_shrink.py
@@ -106,6 +106,7 @@ def _lora_shrink(
     num_tokens_per_lora: torch.Tensor,  # shape [max-loras + 1]
     lora_token_start_loc: torch.Tensor,  # shape [max-loras + 2]
     lora_ids: torch.Tensor,  # shape [max-loras + 1]
+    no_lora_flag_cpu: torch.Tensor,  # shape [1]
     scaling: float,
 ) -> None:
     """
@@ -126,8 +127,16 @@ def _lora_shrink(
             identifies the region in token_indices_sorted_by_lora_ids that
             LoRA lora_ids[i] should process.
         lora_ids (torch.Tensor): LoRA ids to process.
+        no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1 that indicates
+            whether any of the requests require LoRA.
         scaling (float): Scaling factor.
     """
+
+    assert no_lora_flag_cpu.numel() == 1
+    if no_lora_flag_cpu.item():
+        # None of the inputs require LoRA.
+        return
+
     assert inputs.dtype == lora_a_weights[0].dtype
     assert inputs.dtype in [torch.float16, torch.bfloat16]
     for weight in lora_a_weights:
@@ -138,6 +147,8 @@ def _lora_shrink(
     assert output_tensor.is_contiguous()
 
     # metadata sanity check
+    M = inputs.size(0)
+    assert token_lora_mapping.size(0) == M
     assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(
         0)
     assert lora_ids.size(0) == num_tokens_per_lora.size(0)
@@ -146,7 +157,6 @@ def _lora_shrink(
     (lora_ptr_tensor, lora_strides_d0, lora_strides_d1,
      lora_strides_d2) = _get_lora_a_ptr(lora_a_weights, inputs.device)
     N, K = lora_a_weights[0].shape[-2:]  # K=hidden_size,N=rank
-    M = inputs.size(0)
     NUM_SLICES = len(lora_a_weights)
     MAX_LORAS = lora_ids.size(0)
 
@@ -218,6 +228,7 @@ def _lora_shrink_fake(
     num_tokens_per_lora: torch.Tensor,
     lora_token_start_loc: torch.Tensor,
     lora_ids: torch.Tensor,
+    no_lora_flag_cpu: torch.Tensor,
     scaling: float,
 ) -> None:
     return
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 473bd901b5b23..edbafb48c9386 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1242,6 +1242,29 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         max_num_seqs = self.scheduler_config.max_num_seqs
         self._dummy_run(max_num_batched_tokens, max_num_seqs)
 
+    def _add_dummy_loras(self, num_loras: int) -> list[LoRARequest]:
+        assert num_loras > 0
+        assert self.lora_manager is not None
+
+        dummy_lora_requests: list[LoRARequest] = []
+        with self.lora_manager.dummy_lora_cache():
+            for idx in range(num_loras):
+                lora_id = idx + 1
+                dummy_lora_request = LoRARequest(
+                    lora_name=f"warmup_{lora_id}",
+                    lora_int_id=lora_id,
+                    lora_path="/not/a/real/path",
+                )
+                self.lora_manager.add_dummy_lora(dummy_lora_request,
+                                                 rank=LORA_WARMUP_RANK)
+                dummy_lora_requests.append(dummy_lora_request)
+        return dummy_lora_requests
+
+    def _remove_dummy_loras(self):
+        # Remove dummy loras.
+        assert self.lora_manager is not None
+        self.remove_all_loras()
+
     def _dummy_run(self,
                    max_num_batched_tokens: int,
                    max_num_seqs: int = 1) -> None:
@@ -1251,28 +1274,20 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
             SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
 
         # This represents the maximum number of different requests
-        # that will have unique loras, an therefore the max amount of memory
-        # consumption create dummy lora request copies from the lora request
-        # passed in, which contains a lora from the lora warmup path.
+        # that will have unique loras, and therefore the max amount of
+        # memory consumption. Create dummy lora request copies from the
+        # lora request passed in, which contains a lora from the lora
+        # warmup path.
         dummy_lora_requests: List[LoRARequest] = []
         dummy_lora_requests_per_seq: List[LoRARequest] = []
         if self.lora_config:
-            assert self.lora_manager is not None
-            with self.lora_manager.dummy_lora_cache():
-                for idx in range(self.lora_config.max_loras):
-                    lora_id = idx + 1
-                    dummy_lora_request = LoRARequest(
-                        lora_name=f"warmup_{lora_id}",
-                        lora_int_id=lora_id,
-                        lora_path="/not/a/real/path",
-                    )
-                    self.lora_manager.add_dummy_lora(dummy_lora_request,
-                                                     rank=LORA_WARMUP_RANK)
-                    dummy_lora_requests.append(dummy_lora_request)
-                dummy_lora_requests_per_seq = [
-                    dummy_lora_requests[idx % len(dummy_lora_requests)]
-                    for idx in range(max_num_seqs)
-                ]
+            dummy_lora_requests = self._add_dummy_loras(
+                self.lora_config.max_loras)
+            assert len(dummy_lora_requests) == self.lora_config.max_loras
+            dummy_lora_requests_per_seq = [
+                dummy_lora_requests[idx % len(dummy_lora_requests)]
+                for idx in range(max_num_seqs)
+            ]
 
         # Profile memory usage with max_num_sequences sequences and the
         # total number of tokens equal to max_num_batched_tokens.
@@ -1354,9 +1369,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         self.execute_model(model_input, kv_caches, intermediate_tensors)
         torch.cuda.synchronize()
         if self.lora_config:
-            # Remove dummy loras.
-            assert self.lora_manager is not None
-            self.remove_all_loras()
+            self._remove_dummy_loras()
+
         return
 
     def remove_all_loras(self):
@@ -1479,6 +1493,16 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                                             dtype=self.model_config.dtype,
                                             device=self.device)
 
+        dummy_lora_id: Optional[int] = None
+        dummy_lora_request: Optional[LoRARequest] = None
+        if self.lora_config:
+            # The goal is to capture the LoRA kernels in cuda graphs.
+            # For this purpose, a single dummy LoRA is sufficient.
+            dummy_lora_requests = self._add_dummy_loras(num_loras=1)
+            assert len(dummy_lora_requests) == 1
+            dummy_lora_request = dummy_lora_requests[0]
+            dummy_lora_id = dummy_lora_request.lora_int_id
+
         with self.attn_state.graph_capture(max_batch_size), graph_capture(
                 self.device) as graph_capture_context:
             # NOTE: Capturing the largest batch size first may help reduce the
@@ -1503,10 +1527,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                 attn_metadata.enable_kv_scales_calculation = False
                 if self.lora_config:
                     lora_mapping = LoRAMapping(
-                        **dict(index_mapping=[0] * batch_size,
-                               prompt_mapping=[0] * batch_size,
+                        **dict(index_mapping=[dummy_lora_id] * batch_size,
+                               prompt_mapping=[dummy_lora_id] * batch_size,
                                is_prefill=False))
-                    self.set_active_loras(set(), lora_mapping)
+                    self.set_active_loras(set([dummy_lora_request]),
+                                          lora_mapping)
 
                 if self.prompt_adapter_config:
                     prompt_adapter_mapping = PromptAdapterMapping(
@@ -1562,6 +1587,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                 self.graph_runners[virtual_engine][batch_size] = (
                     graph_runner)
 
+        if self.lora_config:
+            self._remove_dummy_loras()
+
         end_time = time.perf_counter()
         end_free_gpu_memory = torch.cuda.mem_get_info()[0]
         elapsed_time = end_time - start_time
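
For readers unfamiliar with the trick the LoRAKernelMeta comment describes, the following standalone sketch (not part of the patch) reduces it to a minimal example. The op name mylib::maybe_add, its body, and the flag name are hypothetical, and it assumes PyTorch >= 2.4 for torch.library.custom_op; the point is only that torch.compile traces the call to a registered op but not its body, so a data-dependent early exit on a CPU flag tensor is safe inside the op even though it would be illegal in the traced graph itself.

import torch


@torch.library.custom_op("mylib::maybe_add", mutates_args=("output",))
def maybe_add(inputs: torch.Tensor, output: torch.Tensor,
              no_work_flag_cpu: torch.Tensor) -> None:
    # The body of a registered custom op is not traced by torch.compile,
    # so this runtime branch on a CPU tensor is allowed here.
    assert no_work_flag_cpu.numel() == 1
    if no_work_flag_cpu.item():
        return  # nothing to do; skip the (stand-in) kernel entirely
    output.add_(inputs)  # stand-in for the real Triton kernel


@maybe_add.register_fake
def _(inputs: torch.Tensor, output: torch.Tensor,
      no_work_flag_cpu: torch.Tensor) -> None:
    # Fake/meta implementation used during tracing; no compute needed.
    return


def forward(x: torch.Tensor, out: torch.Tensor,
            flag_cpu: torch.Tensor) -> torch.Tensor:
    torch.ops.mylib.maybe_add(x, out, flag_cpu)
    return out


compiled = torch.compile(forward)
x, out = torch.ones(4), torch.zeros(4)
print(compiled(x, out, torch.tensor([True])))   # flag set: op early-exits
print(compiled(x, out, torch.tensor([False])))  # flag clear: work happens

Because the flag is an op input rather than a Python scalar, flipping it between forward passes changes behavior at runtime without retracing or recompiling, which mirrors how no_lora_flag_cpu is consumed inside lora_expand / lora_shrink above.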