mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 23:35:01 +08:00
[Hardware][Intel-Gaudi] Enable long-contexts + LoRA support for Intel Gaudi (#12812)
Signed-off-by: Sanju C Sudhakaran <scsudhakaran@habana.ai>
This commit is contained in:
parent
407b5537db
commit
2880e21e3d
@ -1,12 +1,18 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Optional, Tuple, Union, final
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, final
|
||||
|
||||
import torch
|
||||
from vllm_hpu_extension.ops import (dispatch_bgmv_embedding,
|
||||
dispatch_bgmv_linear)
|
||||
|
||||
from .punica_base import PunicaWrapperBase
|
||||
from .utils import convert_mapping
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# avoid circuit import
|
||||
from vllm.lora.layers import LoRAMapping
|
||||
from vllm.lora.models import LongContextLoRAContext
|
||||
|
||||
|
||||
@final
|
||||
@ -19,6 +25,55 @@ class PunicaWrapperHPU(PunicaWrapperBase):
|
||||
PunicaWrapperBase.__init__(self, 3 * max_num_batched_tokens,
|
||||
max_batches, device)
|
||||
|
||||
def _update_base_metadata(
|
||||
self,
|
||||
mapping: "LoRAMapping",
|
||||
lora_index_to_id: List[Optional[int]],
|
||||
max_loras: int,
|
||||
vocab_size: int,
|
||||
extra_vocab_size: int,
|
||||
long_lora_context: Optional["LongContextLoRAContext"] = None,
|
||||
):
|
||||
(
|
||||
base_indices,
|
||||
sampler_indices,
|
||||
sampler_indices_padded,
|
||||
embeddings_indices,
|
||||
long_lora_offsets_tensor,
|
||||
indices_len,
|
||||
) = convert_mapping(mapping, lora_index_to_id, max_loras, vocab_size,
|
||||
extra_vocab_size, self.device, None)
|
||||
# Updating each element in `long_lora_offsets` with `lora_offset` slows
|
||||
# down perf in HPU due to a series of `strided_insert` ops during lazy
|
||||
# graph accumulation. Hence HPU appends `lora_offset` to a list and
|
||||
# converts it to a tensor only after it is ready.
|
||||
if long_lora_context:
|
||||
index_mapping_indices: List[int] = list(
|
||||
mapping.index_mapping).copy()
|
||||
long_lora_offsets: List[int] = []
|
||||
for i in range(len(index_mapping_indices)):
|
||||
lora_offset: int = long_lora_context.offsets_by_lora_id.get(
|
||||
index_mapping_indices[i], 0)
|
||||
long_lora_offsets.append(lora_offset)
|
||||
long_lora_offsets_tensor = torch.tensor(long_lora_offsets,
|
||||
device=self.device,
|
||||
dtype=torch.long)
|
||||
indices_len[-1] = long_lora_offsets_tensor.shape[-1]
|
||||
|
||||
self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices)
|
||||
self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
|
||||
self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
|
||||
sampler_indices_padded)
|
||||
self._embeddings_indices[:embeddings_indices.
|
||||
shape[0], :embeddings_indices.shape[1]].copy_(
|
||||
embeddings_indices)
|
||||
if long_lora_offsets_tensor is not None:
|
||||
self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
|
||||
long_lora_offsets_tensor)
|
||||
else:
|
||||
self._long_lora_indices.zero_()
|
||||
self.indices_len[:] = indices_len
|
||||
|
||||
def add_lora_embedding(self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
|
||||
@ -206,9 +206,10 @@ class RotaryEmbedding(CustomOp):
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
from habana_frameworks.torch.hpex.kernels import (
|
||||
RotaryPosEmbeddingMode, apply_rotary_pos_emb)
|
||||
positions = positions.flatten()
|
||||
if offsets is not None:
|
||||
offsets = offsets.view(positions.shape[0], -1)
|
||||
positions = positions + offsets
|
||||
positions = positions.flatten()
|
||||
num_tokens = positions.shape[0]
|
||||
cos_sin = self.cos_sin_cache.index_select(0, positions).view(
|
||||
num_tokens, 1, -1)
|
||||
|
||||
@ -639,12 +639,25 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
|
||||
"Bias support in LoRA is not enabled in HPU yet."
|
||||
assert not self.lora_config.fully_sharded_loras, \
|
||||
"Fully sharded LoRAs is not enabled in HPU yet."
|
||||
# It's necessary to distinguish between the
|
||||
# max_position_embeddings of VLMs and LLMs.
|
||||
if hasattr(self.model.config, "max_position_embeddings"):
|
||||
max_pos_embeddings = (
|
||||
self.model.config.max_position_embeddings)
|
||||
else:
|
||||
max_pos_embeddings = (
|
||||
self.model.config.text_config.max_position_embeddings)
|
||||
|
||||
self.lora_manager = LRUCacheWorkerLoRAManager(
|
||||
self.scheduler_config.max_num_seqs,
|
||||
self.scheduler_config.max_num_batched_tokens,
|
||||
self.vocab_size, self.lora_config, self.device,
|
||||
self.vocab_size,
|
||||
self.lora_config,
|
||||
self.device,
|
||||
self.model.embedding_modules,
|
||||
self.model.embedding_padding_modules)
|
||||
self.model.embedding_padding_modules,
|
||||
max_position_embeddings=max_pos_embeddings,
|
||||
)
|
||||
self.model = self.lora_manager.create_lora_manager(self.model)
|
||||
|
||||
if self.model_config.quantization == 'inc':
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user