[misc] LoRA - Skip LoRA kernels when not required (#15152)

Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>

parent 33437bc6e7
commit 6c663dfd5e
@@ -136,6 +136,7 @@ def _lora_expand(
         num_tokens_per_lora: torch.Tensor,  # shape [max-loras + 1]
         lora_token_start_loc: torch.Tensor,  # shape [max-loras + 2]
         lora_ids: torch.Tensor,  # shape [max-loras + 1]
+        no_lora_flag_cpu: torch.Tensor,  # shape [1]
         offset_start: int = 0,
         add_inputs: bool = False,
 ) -> None:
@@ -157,11 +158,19 @@ def _lora_expand(
             identifies the the region in token_indices_sorted_by_lora_ids that
             LoRA lora_ids[i] should process.
         lora_ids (torch.Tensor): LoRA ids to process.
+        no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
+            if there are any requests that require LoRA.
         offset_start (int, optional): Offset start for output_tensor.
             Defaults to 0.
         add_inputs (bool, optional): Whether to add the input tensor to the
             output tensor. Defaults to False.
     """
+
+    assert no_lora_flag_cpu.numel() == 1
+    if no_lora_flag_cpu.item():
+        # None of the inputs require LoRA.
+        return
+
     assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
     for weight in lora_b_weights:
         assert weight.dtype in [torch.float16, torch.bfloat16]
@@ -170,6 +179,8 @@ def _lora_expand(
     assert output_tensor.is_contiguous()

     # metadata sanity check.
+    M = inputs.size(1)
+    assert token_lora_mapping.size(0) == M
     assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(
         0)
     assert lora_ids.size(0) == num_tokens_per_lora.size(0)
@@ -181,7 +192,6 @@ def _lora_expand(
                                      inputs.device)

     K = lora_b_weights[0].shape[-1]  # K= rank
-    M = inputs.size(1)
     ADD_INPUTS = add_inputs
     MAX_LORAS = lora_ids.size(0)
     CAST_TYPE = False
@@ -263,6 +273,7 @@ def _lora_expand_fake(
         num_tokens_per_lora: torch.Tensor,
         lora_token_start_loc: torch.Tensor,
         lora_ids: torch.Tensor,
+        no_lora_flag_cpu: torch.Tensor,
         offset_start: int = 0,
         add_inputs: bool = False,
 ) -> None:
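
The early exit above works because no_lora_flag_cpu lives on the CPU, so no_lora_flag_cpu.item() is a cheap host-side read, and because it is evaluated inside a registered torch op, whose internals torch.compile does not trace. The following is a minimal, illustrative sketch of that pattern, not vLLM code; the op name "toy::lora_like" and both function names are made up for this example.

import torch


@torch.library.custom_op("toy::lora_like", mutates_args=["out"])
def toy_lora_like(x: torch.Tensor, out: torch.Tensor,
                  no_lora_flag_cpu: torch.Tensor) -> None:
    # Hypothetical op: reads a CPU bool tensor and returns before doing any
    # work when no request in the batch needs LoRA.
    assert no_lora_flag_cpu.numel() == 1
    if no_lora_flag_cpu.item():
        return
    out.add_(x)  # stand-in for the real Triton kernel launch


@toy_lora_like.register_fake
def _toy_lora_like_fake(x: torch.Tensor, out: torch.Tensor,
                        no_lora_flag_cpu: torch.Tensor) -> None:
    # The fake (meta) implementation never inspects the flag, so the
    # data-dependent branch stays hidden from torch.compile tracing.
    return

Under torch.compile, the traced graph always contains the op call; whether the kernel body actually runs is decided at run time from the flag, which the caller can flip between steps without retracing.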
@@ -17,6 +17,17 @@ class LoRAKernelMeta:
     num_tokens_per_lora: torch.Tensor
     lora_token_start_loc: torch.Tensor

+    # The V1 architecture uses the traced torch.compile graphs to execute
+    # a forward pass. Things to note about this process,
+    # 1. The tracing infers all python scalar datatype objects into a constant
+    # value.
+    # 2. The tracing cannot handle dynamic control flow. (dynamic control flow
+    # is an experimental feature in pytorch)
+    # 3. The internals of torch.ops functions are not traced.
+    # We disguise the "no_lora" flag as a cpu tensor and leverage point number 3
+    # to early exit from inside the lora_expand / lora_shrink torch operation.
+    no_lora_flag_cpu: torch.Tensor
+
     @staticmethod
     def make(max_loras: int, max_num_tokens: int,
              device: Union[torch.device, str]) -> "LoRAKernelMeta":
@@ -47,17 +58,24 @@ class LoRAKernelMeta:
         lora_token_start_loc = torch.zeros(max_loras + 2,
                                            dtype=torch.int32,
                                            device=device)
+
+        no_lora_flag_cpu = torch.tensor([False],
+                                        dtype=torch.bool,
+                                        device='cpu')
+
         return LoRAKernelMeta(
             token_lora_mapping=token_lora_mapping,
             token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids,
             active_lora_ids=active_lora_ids,
             num_tokens_per_lora=num_tokens_per_lora,
-            lora_token_start_loc=lora_token_start_loc)
+            lora_token_start_loc=lora_token_start_loc,
+            no_lora_flag_cpu=no_lora_flag_cpu)

     def _reset(self):
         self.active_lora_ids.fill_(-1)
         self.num_tokens_per_lora.fill_(0)
         self.lora_token_start_loc.fill_(0)
+        self.no_lora_flag_cpu.fill_(False)

     def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
         """
@@ -70,6 +88,14 @@ class LoRAKernelMeta:

         self._reset()

+        # Check and record no-lora case.
+        no_lora = torch.all(token_lora_mapping == -1)
+        self.no_lora_flag_cpu[0] = no_lora
+
+        if no_lora:
+            # Early exit. LoRA kernels will not be run.
+            return
+
         num_tokens = token_lora_mapping.size(0)

         # copy token lora mapping
@@ -100,7 +126,7 @@ class LoRAKernelMeta:
     def meta_args(
         self, token_nums: int
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
-               torch.Tensor]:
+               torch.Tensor, torch.Tensor]:
         """
         This function returns the kernel metadata required for the current
         forward pass execution of the kernel. The function returns all the
@@ -111,7 +137,11 @@ class LoRAKernelMeta:
             token_nums (int): Number of input tokens in the current forward
                 pass.
         """
-        return (self.token_lora_mapping[:token_nums],
-                self.token_indices_sorted_by_lora_ids[:token_nums],
-                self.num_tokens_per_lora, self.lora_token_start_loc,
-                self.active_lora_ids)
+        return (
+            self.token_lora_mapping[:token_nums],
+            self.token_indices_sorted_by_lora_ids[:token_nums],
+            self.num_tokens_per_lora,
+            self.lora_token_start_loc,
+            self.active_lora_ids,
+            self.no_lora_flag_cpu,
+        )
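
The comment block added above is the heart of the change: a plain Python bool would be burned into the torch.compile trace as a constant, and a data-dependent if cannot be traced, but a CPU tensor can be updated in place every step and inspected inside the untraced custom op. A rough usage sketch of the resulting metadata flow follows; the import path, device, and token counts are assumptions for illustration, not taken from the diff.

import torch

# Assumed import location; the diff only shows the class body.
from vllm.lora.ops.triton_ops import LoRAKernelMeta

meta = LoRAKernelMeta.make(max_loras=4, max_num_tokens=256, device="cpu")

# A mapping of all -1 means no token in this batch uses a LoRA adapter.
token_lora_mapping = torch.full((16, ), -1, dtype=torch.int32)
meta.prepare_tensors(token_lora_mapping)

# meta_args() now returns six tensors; the last is no_lora_flag_cpu, which
# lora_shrink / lora_expand consult to skip their kernels entirely.
*kernel_meta, no_lora_flag_cpu = meta.meta_args(token_nums=16)
assert bool(no_lora_flag_cpu.item())

In this case prepare_tensors() also returns early, so none of the sorting and index bookkeeping for the kernels runs either.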
@@ -106,6 +106,7 @@ def _lora_shrink(
         num_tokens_per_lora: torch.Tensor,  # shape [max-loras + 1]
         lora_token_start_loc: torch.Tensor,  # shape [max-loras + 2]
         lora_ids: torch.Tensor,  # shape [max-loras + 1]
+        no_lora_flag_cpu: torch.Tensor,  # shape [1]
         scaling: float,
 ) -> None:
     """
@@ -126,8 +127,16 @@ def _lora_shrink(
             identifies the region in token_indices_sorted_by_lora_ids that
             LoRA lora_ids[i] should process.
         lora_ids (torch.Tensor): LoRA ids to process.
+        no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
+            if there are any requests that require LoRA.
         scaling (float): Scaling factor.
     """
+
+    assert no_lora_flag_cpu.numel() == 1
+    if no_lora_flag_cpu.item():
+        # None of the inputs require LoRA.
+        return
+
     assert inputs.dtype == lora_a_weights[0].dtype
     assert inputs.dtype in [torch.float16, torch.bfloat16]
     for weight in lora_a_weights:
@@ -138,6 +147,8 @@ def _lora_shrink(
     assert output_tensor.is_contiguous()

     # metadata sanity check
+    M = inputs.size(0)
+    assert token_lora_mapping.size(0) == M
     assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(
         0)
     assert lora_ids.size(0) == num_tokens_per_lora.size(0)
@@ -146,7 +157,6 @@ def _lora_shrink(
     (lora_ptr_tensor, lora_strides_d0, lora_strides_d1,
      lora_strides_d2) = _get_lora_a_ptr(lora_a_weights, inputs.device)
     N, K = lora_a_weights[0].shape[-2:]  # K=hidden_size,N=rank
-    M = inputs.size(0)
     NUM_SLICES = len(lora_a_weights)
     MAX_LORAS = lora_ids.size(0)

@@ -218,6 +228,7 @@ def _lora_shrink_fake(
         num_tokens_per_lora: torch.Tensor,
         lora_token_start_loc: torch.Tensor,
         lora_ids: torch.Tensor,
+        no_lora_flag_cpu: torch.Tensor,
         scaling: float,
 ) -> None:
     return
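
For reference, the two ops guarded by the flag implement the usual LoRA factorization: lora_shrink projects token activations down through the LoRA A matrices (with scaling), and lora_expand projects the result back up through LoRA B, optionally accumulating into the base model output. Below is a plain-PyTorch sketch of the per-adapter semantics only; it ignores the multi-LoRA grouping metadata and is not the Triton implementation.

import torch


def lora_shrink_ref(x: torch.Tensor, lora_a: torch.Tensor,
                    scaling: float) -> torch.Tensor:
    # x: [num_tokens, hidden_size], lora_a: [rank, hidden_size]
    # -> [num_tokens, rank]   (matches "K=hidden_size, N=rank" in the diff)
    return scaling * (x @ lora_a.t())


def lora_expand_ref(y: torch.Tensor, lora_b: torch.Tensor,
                    base_out: torch.Tensor,
                    add_inputs: bool = True) -> torch.Tensor:
    # y: [num_tokens, rank], lora_b: [out_size, rank] -> [num_tokens, out_size]
    delta = y @ lora_b.t()
    return base_out + delta if add_inputs else delta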
@@ -1242,6 +1242,29 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         max_num_seqs = self.scheduler_config.max_num_seqs
         self._dummy_run(max_num_batched_tokens, max_num_seqs)

+    def _add_dummy_loras(self, num_loras: int) -> list[LoRARequest]:
+        assert num_loras > 0
+        assert self.lora_manager is not None
+
+        dummy_lora_requests: list[LoRARequest] = []
+        with self.lora_manager.dummy_lora_cache():
+            for idx in range(num_loras):
+                lora_id = idx + 1
+                dummy_lora_request = LoRARequest(
+                    lora_name=f"warmup_{lora_id}",
+                    lora_int_id=lora_id,
+                    lora_path="/not/a/real/path",
+                )
+                self.lora_manager.add_dummy_lora(dummy_lora_request,
+                                                 rank=LORA_WARMUP_RANK)
+                dummy_lora_requests.append(dummy_lora_request)
+        return dummy_lora_requests
+
+    def _remove_dummy_loras(self):
+        # Remove dummy loras.
+        assert self.lora_manager is not None
+        self.remove_all_loras()
+
     def _dummy_run(self,
                    max_num_batched_tokens: int,
                    max_num_seqs: int = 1) -> None:
@@ -1251,28 +1274,20 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
             SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)

         # This represents the maximum number of different requests
-        # that will have unique loras, an therefore the max amount of memory
-        # consumption create dummy lora request copies from the lora request
-        # passed in, which contains a lora from the lora warmup path.
+        # that will have unique loras, and therefore the max amount of
+        # memory consumption. Create dummy lora request copies from the
+        # lora request passed in, which contains a lora from the lora
+        # warmup path.
         dummy_lora_requests: List[LoRARequest] = []
         dummy_lora_requests_per_seq: List[LoRARequest] = []
         if self.lora_config:
-            assert self.lora_manager is not None
-            with self.lora_manager.dummy_lora_cache():
-                for idx in range(self.lora_config.max_loras):
-                    lora_id = idx + 1
-                    dummy_lora_request = LoRARequest(
-                        lora_name=f"warmup_{lora_id}",
-                        lora_int_id=lora_id,
-                        lora_path="/not/a/real/path",
-                    )
-                    self.lora_manager.add_dummy_lora(dummy_lora_request,
-                                                     rank=LORA_WARMUP_RANK)
-                    dummy_lora_requests.append(dummy_lora_request)
-            dummy_lora_requests_per_seq = [
-                dummy_lora_requests[idx % len(dummy_lora_requests)]
-                for idx in range(max_num_seqs)
-            ]
+            dummy_lora_requests = self._add_dummy_loras(
+                self.lora_config.max_loras)
+            assert len(dummy_lora_requests) == self.lora_config.max_loras
+            dummy_lora_requests_per_seq = [
+                dummy_lora_requests[idx % len(dummy_lora_requests)]
+                for idx in range(max_num_seqs)
+            ]

         # Profile memory usage with max_num_sequences sequences and the
         # total number of tokens equal to max_num_batched_tokens.
@@ -1354,9 +1369,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         self.execute_model(model_input, kv_caches, intermediate_tensors)
         torch.cuda.synchronize()
         if self.lora_config:
-            # Remove dummy loras.
-            assert self.lora_manager is not None
-            self.remove_all_loras()
+            self._remove_dummy_loras()
+
         return

     def remove_all_loras(self):
@@ -1479,6 +1493,16 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                 dtype=self.model_config.dtype,
                 device=self.device)

+        dummy_lora_id: Optional[int] = None
+        dummy_lora_request: LoRARequest = []
+        if self.lora_config:
+            # The goal is to capture the LoRA kernels in cuda graphs.
+            # for this purpose, as single dummy lora is sufficient.
+            dummy_lora_requests = self._add_dummy_loras(num_loras=1)
+            assert len(dummy_lora_requests) == 1
+            dummy_lora_request = dummy_lora_requests[0]
+            dummy_lora_id = dummy_lora_request.lora_int_id
+
         with self.attn_state.graph_capture(max_batch_size), graph_capture(
                 self.device) as graph_capture_context:
             # NOTE: Capturing the largest batch size first may help reduce the
@@ -1503,10 +1527,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                 attn_metadata.enable_kv_scales_calculation = False
                 if self.lora_config:
                     lora_mapping = LoRAMapping(
-                        **dict(index_mapping=[0] * batch_size,
-                               prompt_mapping=[0] * batch_size,
+                        **dict(index_mapping=[dummy_lora_id] * batch_size,
+                               prompt_mapping=[dummy_lora_id] * batch_size,
                                is_prefill=False))
-                    self.set_active_loras(set(), lora_mapping)
+                    self.set_active_loras(set([dummy_lora_request]),
+                                          lora_mapping)

                 if self.prompt_adapter_config:
                     prompt_adapter_mapping = PromptAdapterMapping(
@@ -1562,6 +1587,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                 self.graph_runners[virtual_engine][batch_size] = (
                     graph_runner)

+        if self.lora_config:
+            self._remove_dummy_loras()
+
         end_time = time.perf_counter()
         end_free_gpu_memory = torch.cuda.mem_get_info()[0]
         elapsed_time = end_time - start_time
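
The model-runner changes make sure the LoRA path is exercised at the two places where it matters: memory profiling and CUDA-graph capture. During capture, every slot is mapped to a single dummy adapter so that the lora_shrink / lora_expand launches are recorded inside the graph, and the dummy adapter is removed once capture finishes. A simplified sketch of that flow is shown below; the LoRAMapping import path and the capture-loop scaffolding are assumptions, and only the _add_dummy_loras / _remove_dummy_loras calls come from the diff above.

from vllm.lora.layers import LoRAMapping  # import path assumed


def capture_with_dummy_lora(runner, batch_sizes):
    """Illustrative only: mirrors the flow added to GPUModelRunnerBase."""
    dummy_lora_request = None
    dummy_lora_id = None
    if runner.lora_config:
        # One dummy adapter is enough to pull the LoRA kernels into the graph.
        dummy_lora_request = runner._add_dummy_loras(num_loras=1)[0]
        dummy_lora_id = dummy_lora_request.lora_int_id

    for batch_size in batch_sizes:
        if runner.lora_config:
            # Point every captured slot at the dummy adapter (previously id 0,
            # i.e. "no LoRA", which kept the kernels out of the graph).
            lora_mapping = LoRAMapping(
                index_mapping=[dummy_lora_id] * batch_size,
                prompt_mapping=[dummy_lora_id] * batch_size,
                is_prefill=False)
            runner.set_active_loras({dummy_lora_request}, lora_mapping)
        ...  # run the forward pass under CUDA-graph capture

    if runner.lora_config:
        runner._remove_dummy_loras()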