[misc] LoRA - Skip LoRA kernels when not required (#15152)

Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Varun Sundar Rabindranath 2025-03-25 20:33:45 -07:00 committed by GitHub
parent 33437bc6e7
commit 6c663dfd5e
4 changed files with 113 additions and 33 deletions

View File

@@ -136,6 +136,7 @@ def _lora_expand(
     num_tokens_per_lora: torch.Tensor,  # shape [max-loras + 1]
     lora_token_start_loc: torch.Tensor,  # shape [max-loras + 2]
     lora_ids: torch.Tensor,  # shape [max-loras + 1]
+    no_lora_flag_cpu: torch.Tensor,  # shape [1]
     offset_start: int = 0,
     add_inputs: bool = False,
 ) -> None:
@@ -157,11 +158,19 @@ def _lora_expand(
             identifies the region in token_indices_sorted_by_lora_ids that
             LoRA lora_ids[i] should process.
         lora_ids (torch.Tensor): LoRA ids to process.
+        no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
+            if there are any requests that require LoRA.
         offset_start (int, optional): Offset start for output_tensor.
             Defaults to 0.
         add_inputs (bool, optional): Whether to add the input tensor to the
             output tensor. Defaults to False.
     """
+
+    assert no_lora_flag_cpu.numel() == 1
+    if no_lora_flag_cpu.item():
+        # None of the inputs require LoRA.
+        return
+
     assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
     for weight in lora_b_weights:
         assert weight.dtype in [torch.float16, torch.bfloat16]
@@ -170,6 +179,8 @@ def _lora_expand(
     assert output_tensor.is_contiguous()

     # metadata sanity check.
+    M = inputs.size(1)
+    assert token_lora_mapping.size(0) == M
     assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(
         0)
     assert lora_ids.size(0) == num_tokens_per_lora.size(0)
@@ -181,7 +192,6 @@ def _lora_expand(
                                           inputs.device)

     K = lora_b_weights[0].shape[-1]  # K= rank
-    M = inputs.size(1)
     ADD_INPUTS = add_inputs
     MAX_LORAS = lora_ids.size(0)
     CAST_TYPE = False
@@ -263,6 +273,7 @@ def _lora_expand_fake(
     num_tokens_per_lora: torch.Tensor,
     lora_token_start_loc: torch.Tensor,
     lora_ids: torch.Tensor,
+    no_lora_flag_cpu: torch.Tensor,
     offset_start: int = 0,
     add_inputs: bool = False,
 ) -> None:
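
The guard above is the core of the change: a CPU-resident boolean tensor is checked before any Triton work is launched. Below is a minimal, self-contained sketch of that behavior; the function name, shapes, and the `output.add_` stand-in are illustrative, not the actual vLLM op.

    import torch

    def maybe_run_lora_kernel(output: torch.Tensor, delta: torch.Tensor,
                              no_lora_flag_cpu: torch.Tensor) -> None:
        """Illustrative stand-in for the guard added to _lora_expand/_lora_shrink."""
        assert no_lora_flag_cpu.numel() == 1
        if no_lora_flag_cpu.item():
            # None of the tokens in this batch map to a LoRA adapter:
            # skip the (comparatively expensive) kernel launch entirely.
            return
        # Stand-in for the real Triton expand/shrink kernel.
        output.add_(delta)

    no_lora_flag_cpu = torch.tensor([True], dtype=torch.bool, device="cpu")
    out = torch.zeros(4)
    maybe_run_lora_kernel(out, torch.ones(4), no_lora_flag_cpu)  # no-op: flag is set
    no_lora_flag_cpu.fill_(False)
    maybe_run_lora_kernel(out, torch.ones(4), no_lora_flag_cpu)  # runs the "kernel"

Keeping the flag on the CPU matters: calling .item() on a CUDA tensor would force a device synchronization, while reading a CPU tensor is essentially free.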

View File

@@ -17,6 +17,17 @@ class LoRAKernelMeta:
     num_tokens_per_lora: torch.Tensor
     lora_token_start_loc: torch.Tensor
+
+    # The V1 architecture uses traced torch.compile graphs to execute a
+    # forward pass. Things to note about this process:
+    # 1. Tracing bakes all Python scalar objects into the graph as
+    #    constant values.
+    # 2. Tracing cannot handle dynamic control flow (dynamic control flow
+    #    is an experimental feature in PyTorch).
+    # 3. The internals of torch.ops functions are not traced.
+    # We disguise the "no_lora" flag as a CPU tensor and leverage point 3 to
+    # early-exit from inside the lora_expand / lora_shrink torch operations.
+    no_lora_flag_cpu: torch.Tensor

     @staticmethod
     def make(max_loras: int, max_num_tokens: int,
              device: Union[torch.device, str]) -> "LoRAKernelMeta":
@@ -47,17 +58,24 @@ class LoRAKernelMeta:
         lora_token_start_loc = torch.zeros(max_loras + 2,
                                            dtype=torch.int32,
                                            device=device)
+
+        no_lora_flag_cpu = torch.tensor([False],
+                                        dtype=torch.bool,
+                                        device='cpu')
+
         return LoRAKernelMeta(
             token_lora_mapping=token_lora_mapping,
             token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids,
             active_lora_ids=active_lora_ids,
             num_tokens_per_lora=num_tokens_per_lora,
-            lora_token_start_loc=lora_token_start_loc)
+            lora_token_start_loc=lora_token_start_loc,
+            no_lora_flag_cpu=no_lora_flag_cpu)

     def _reset(self):
         self.active_lora_ids.fill_(-1)
         self.num_tokens_per_lora.fill_(0)
         self.lora_token_start_loc.fill_(0)
+        self.no_lora_flag_cpu.fill_(False)

     def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
         """
@@ -70,6 +88,14 @@ class LoRAKernelMeta:
         self._reset()

+        # Check and record no-lora case.
+        no_lora = torch.all(token_lora_mapping == -1)
+        self.no_lora_flag_cpu[0] = no_lora
+        if no_lora:
+            # Early exit. LoRA kernels will not be run.
+            return
+
         num_tokens = token_lora_mapping.size(0)

         # copy token lora mapping
@@ -100,7 +126,7 @@ class LoRAKernelMeta:
     def meta_args(
         self, token_nums: int
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
-               torch.Tensor]:
+               torch.Tensor, torch.Tensor]:
         """
         This function returns the kernel metadata required for the current
         forward pass execution of the kernel. The function returns all the
@@ -111,7 +137,11 @@ class LoRAKernelMeta:
             token_nums (int): Number of input tokens in the current forward
                 pass.
         """
-        return (self.token_lora_mapping[:token_nums],
-                self.token_indices_sorted_by_lora_ids[:token_nums],
-                self.num_tokens_per_lora, self.lora_token_start_loc,
-                self.active_lora_ids)
+        return (
+            self.token_lora_mapping[:token_nums],
+            self.token_indices_sorted_by_lora_ids[:token_nums],
+            self.num_tokens_per_lora,
+            self.lora_token_start_loc,
+            self.active_lora_ids,
+            self.no_lora_flag_cpu,
+        )
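
prepare_tensors is the only place the flag is written: it is set from the host-side token/LoRA mapping before the forward pass, and meta_args now forwards it as the sixth element of the metadata tuple. A rough sketch of that flow, using a trimmed-down metadata holder (field names mirror the class above; everything else is simplified):

    from dataclasses import dataclass
    import torch

    @dataclass
    class MiniLoRAMeta:
        # Only the pieces relevant to the no-LoRA fast path are kept here.
        token_lora_mapping: torch.Tensor  # per-token LoRA id, -1 == "no LoRA"
        no_lora_flag_cpu: torch.Tensor    # shape [1], bool, deliberately on the CPU

        def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
            no_lora = torch.all(token_lora_mapping == -1)
            self.no_lora_flag_cpu[0] = no_lora
            if no_lora:
                # Nothing else to prepare; the kernels early-exit on this flag.
                return
            num_tokens = token_lora_mapping.size(0)
            self.token_lora_mapping[:num_tokens].copy_(token_lora_mapping)

    meta = MiniLoRAMeta(
        token_lora_mapping=torch.zeros(16, dtype=torch.int32),
        no_lora_flag_cpu=torch.tensor([False], dtype=torch.bool, device="cpu"),
    )
    meta.prepare_tensors(torch.full((8,), -1, dtype=torch.int32))  # batch without LoRA
    assert meta.no_lora_flag_cpu.item()
    meta.prepare_tensors(torch.tensor([0, 0, 1, -1], dtype=torch.int32))  # mixed batch
    assert not meta.no_lora_flag_cpu.item()

The flag stays a tensor rather than a Python bool because of the torch.compile constraints listed in the comment above: a Python scalar would be baked into the traced graph as a constant, whereas a CPU tensor can change between forward passes without retracing.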

View File

@@ -106,6 +106,7 @@ def _lora_shrink(
     num_tokens_per_lora: torch.Tensor,  # shape [max-loras + 1]
     lora_token_start_loc: torch.Tensor,  # shape [max-loras + 2]
     lora_ids: torch.Tensor,  # shape [max-loras + 1]
+    no_lora_flag_cpu: torch.Tensor,  # shape [1]
     scaling: float,
 ) -> None:
     """
@@ -126,8 +127,16 @@ def _lora_shrink(
             identifies the region in token_indices_sorted_by_lora_ids that
             LoRA lora_ids[i] should process.
         lora_ids (torch.Tensor): LoRA ids to process.
+        no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
+            if there are any requests that require LoRA.
         scaling (float): Scaling factor.
     """
+
+    assert no_lora_flag_cpu.numel() == 1
+    if no_lora_flag_cpu.item():
+        # None of the inputs require LoRA.
+        return
+
     assert inputs.dtype == lora_a_weights[0].dtype
     assert inputs.dtype in [torch.float16, torch.bfloat16]
     for weight in lora_a_weights:
@@ -138,6 +147,8 @@ def _lora_shrink(
     assert output_tensor.is_contiguous()

     # metadata sanity check
+    M = inputs.size(0)
+    assert token_lora_mapping.size(0) == M
     assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(
         0)
     assert lora_ids.size(0) == num_tokens_per_lora.size(0)
@@ -146,7 +157,6 @@ def _lora_shrink(
     (lora_ptr_tensor, lora_strides_d0, lora_strides_d1,
      lora_strides_d2) = _get_lora_a_ptr(lora_a_weights, inputs.device)
     N, K = lora_a_weights[0].shape[-2:]  # K=hidden_size,N=rank
-    M = inputs.size(0)
     NUM_SLICES = len(lora_a_weights)
     MAX_LORAS = lora_ids.size(0)
@@ -218,6 +228,7 @@ def _lora_shrink_fake(
     num_tokens_per_lora: torch.Tensor,
     lora_token_start_loc: torch.Tensor,
     lora_ids: torch.Tensor,
+    no_lora_flag_cpu: torch.Tensor,
     scaling: float,
 ) -> None:
     return
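
Both ops also thread the new argument through their *_fake counterparts, because torch.compile only sees the registered op's signature and its fake (meta) implementation, never the body. A rough sketch of that registration pattern using the stock torch.library.custom_op decorator (PyTorch >= 2.4); vLLM registers its kernels through its own helper, so the decorator style and op name here are assumptions:

    import torch

    @torch.library.custom_op("demo::maybe_lora_add", mutates_args=("output",))
    def maybe_lora_add(output: torch.Tensor, delta: torch.Tensor,
                       no_lora_flag_cpu: torch.Tensor) -> None:
        # The op body is opaque to torch.compile (point 3 in the comment above),
        # so this data-dependent branch is legal even inside a traced graph.
        if no_lora_flag_cpu.item():
            return
        output.add_(delta)

    @maybe_lora_add.register_fake
    def _(output: torch.Tensor, delta: torch.Tensor,
          no_lora_flag_cpu: torch.Tensor) -> None:
        # The fake implementation must mirror the real signature, which is why
        # _lora_expand_fake / _lora_shrink_fake also gained no_lora_flag_cpu.
        return

    out = torch.zeros(4)
    flag = torch.tensor([True], dtype=torch.bool, device="cpu")
    torch.ops.demo.maybe_lora_add(out, torch.ones(4), flag)  # skipped: out stays zero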

View File

@@ -1242,6 +1242,29 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         max_num_seqs = self.scheduler_config.max_num_seqs
         self._dummy_run(max_num_batched_tokens, max_num_seqs)

+    def _add_dummy_loras(self, num_loras: int) -> list[LoRARequest]:
+        assert num_loras > 0
+        assert self.lora_manager is not None
+
+        dummy_lora_requests: list[LoRARequest] = []
+        with self.lora_manager.dummy_lora_cache():
+            for idx in range(num_loras):
+                lora_id = idx + 1
+                dummy_lora_request = LoRARequest(
+                    lora_name=f"warmup_{lora_id}",
+                    lora_int_id=lora_id,
+                    lora_path="/not/a/real/path",
+                )
+                self.lora_manager.add_dummy_lora(dummy_lora_request,
+                                                 rank=LORA_WARMUP_RANK)
+                dummy_lora_requests.append(dummy_lora_request)
+        return dummy_lora_requests
+
+    def _remove_dummy_loras(self):
+        # Remove dummy loras.
+        assert self.lora_manager is not None
+        self.remove_all_loras()
+
     def _dummy_run(self,
                    max_num_batched_tokens: int,
                    max_num_seqs: int = 1) -> None:
@@ -1251,28 +1274,20 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
            SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)

        # This represents the maximum number of different requests
-       # that will have unique loras, an therefore the max amount of memory
-       # consumption create dummy lora request copies from the lora request
-       # passed in, which contains a lora from the lora warmup path.
+       # that will have unique loras, and therefore the max amount of
+       # memory consumption. Create dummy lora request copies from the
+       # lora request passed in, which contains a lora from the lora
+       # warmup path.
        dummy_lora_requests: List[LoRARequest] = []
        dummy_lora_requests_per_seq: List[LoRARequest] = []
        if self.lora_config:
-           assert self.lora_manager is not None
-           with self.lora_manager.dummy_lora_cache():
-               for idx in range(self.lora_config.max_loras):
-                   lora_id = idx + 1
-                   dummy_lora_request = LoRARequest(
-                       lora_name=f"warmup_{lora_id}",
-                       lora_int_id=lora_id,
-                       lora_path="/not/a/real/path",
-                   )
-                   self.lora_manager.add_dummy_lora(dummy_lora_request,
-                                                    rank=LORA_WARMUP_RANK)
-                   dummy_lora_requests.append(dummy_lora_request)
-           dummy_lora_requests_per_seq = [
-               dummy_lora_requests[idx % len(dummy_lora_requests)]
-               for idx in range(max_num_seqs)
-           ]
+           dummy_lora_requests = self._add_dummy_loras(
+               self.lora_config.max_loras)
+           assert len(dummy_lora_requests) == self.lora_config.max_loras
+           dummy_lora_requests_per_seq = [
+               dummy_lora_requests[idx % len(dummy_lora_requests)]
+               for idx in range(max_num_seqs)
+           ]

        # Profile memory usage with max_num_sequences sequences and the
        # total number of tokens equal to max_num_batched_tokens.
@@ -1354,9 +1369,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
             self.execute_model(model_input, kv_caches, intermediate_tensors)
         torch.cuda.synchronize()
         if self.lora_config:
-            # Remove dummy loras.
-            assert self.lora_manager is not None
-            self.remove_all_loras()
+            self._remove_dummy_loras()
         return

     def remove_all_loras(self):
@@ -1479,6 +1493,16 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                                     dtype=self.model_config.dtype,
                                     device=self.device)

+        dummy_lora_id: Optional[int] = None
+        dummy_lora_request: Optional[LoRARequest] = None
+        if self.lora_config:
+            # The goal is to capture the LoRA kernels in cuda graphs.
+            # For this purpose, a single dummy lora is sufficient.
+            dummy_lora_requests = self._add_dummy_loras(num_loras=1)
+            assert len(dummy_lora_requests) == 1
+            dummy_lora_request = dummy_lora_requests[0]
+            dummy_lora_id = dummy_lora_request.lora_int_id
+
         with self.attn_state.graph_capture(max_batch_size), graph_capture(
                 self.device) as graph_capture_context:
             # NOTE: Capturing the largest batch size first may help reduce the
@@ -1503,10 +1527,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                 attn_metadata.enable_kv_scales_calculation = False
                 if self.lora_config:
                     lora_mapping = LoRAMapping(
-                        **dict(index_mapping=[0] * batch_size,
-                               prompt_mapping=[0] * batch_size,
+                        **dict(index_mapping=[dummy_lora_id] * batch_size,
+                               prompt_mapping=[dummy_lora_id] * batch_size,
                                is_prefill=False))
-                    self.set_active_loras(set(), lora_mapping)
+                    self.set_active_loras(set([dummy_lora_request]),
+                                          lora_mapping)

                 if self.prompt_adapter_config:
                     prompt_adapter_mapping = PromptAdapterMapping(
@@ -1562,6 +1587,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                 self.graph_runners[virtual_engine][batch_size] = (
                     graph_runner)

+        if self.lora_config:
+            self._remove_dummy_loras()
+
         end_time = time.perf_counter()
         end_free_gpu_memory = torch.cuda.mem_get_info()[0]
         elapsed_time = end_time - start_time
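
One subtlety in the model-runner changes: with the new early-exit, a CUDA-graph capture that maps every token to LoRA id 0 (the "no LoRA" id) would set the no-lora flag, skip the LoRA kernels, and never bake them into the graph. Hence the dummy adapter with a non-zero id during capture. A short sketch of that setup, using the LoRARequest fields shown in the diff (the import path is the usual vLLM location; the surrounding runner plumbing is omitted):

    from vllm.lora.request import LoRARequest

    # A single dummy adapter is enough to force the LoRA path during capture.
    dummy_lora_request = LoRARequest(
        lora_name="warmup_1",
        lora_int_id=1,                 # non-zero: tokens are treated as LoRA tokens
        lora_path="/not/a/real/path",
    )
    dummy_lora_id = dummy_lora_request.lora_int_id

    batch_size = 8
    # An all-zero mapping would trigger the no-lora early exit and bypass the
    # kernels; mapping every slot to the dummy id keeps them in the captured graph.
    index_mapping = [dummy_lora_id] * batch_size
    prompt_mapping = [dummy_lora_id] * batch_size
    active_requests = {dummy_lora_request}

In the diff, set_active_loras(set([dummy_lora_request]), lora_mapping) makes the dummy adapter resident before each capture, and _remove_dummy_loras() cleans it up once all graphs have been captured.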