mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-21 05:17:04 +08:00
[Bugfix] In LongRoPE, decide short vs long based on max_model_len (#27431)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
parent
7a865f2325
commit
44b5ce956d
@ -29,7 +29,7 @@ def multimodal_server(): # noqa: F811
|
|||||||
"--dtype",
|
"--dtype",
|
||||||
"half",
|
"half",
|
||||||
"--max-model-len",
|
"--max-model-len",
|
||||||
"12800",
|
"4096",
|
||||||
"--enforce-eager",
|
"--enforce-eager",
|
||||||
# lora config below
|
# lora config below
|
||||||
"--enable-lora",
|
"--enable-lora",
|
||||||
|
|||||||
@ -2142,8 +2142,18 @@ def _get_and_verify_max_len(
|
|||||||
# If the user didn't specify `max_model_len`, then use that derived from
|
# If the user didn't specify `max_model_len`, then use that derived from
|
||||||
# the model config as a default value.
|
# the model config as a default value.
|
||||||
if max_model_len is None:
|
if max_model_len is None:
|
||||||
max_model_len = int(derived_max_model_len)
|
# For LongRoPE, default to original_max_position_embeddings to avoid
|
||||||
|
# performance degradation for shorter sequences
|
||||||
|
if rope_scaling is not None and rope_scaling["rope_type"] == "longrope":
|
||||||
|
max_model_len = int(
|
||||||
|
getattr(
|
||||||
|
hf_config, "original_max_position_embeddings", derived_max_model_len
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
max_model_len = int(derived_max_model_len)
|
||||||
max_model_len = current_platform.check_max_model_len(max_model_len)
|
max_model_len = current_platform.check_max_model_len(max_model_len)
|
||||||
|
|
||||||
# If the user specified a max length, make sure it is smaller than the
|
# If the user specified a max length, make sure it is smaller than the
|
||||||
# derived length from the HF model config.
|
# derived length from the HF model config.
|
||||||
elif max_model_len > derived_max_model_len:
|
elif max_model_len > derived_max_model_len:
|
||||||
|
|||||||
@ -5,8 +5,13 @@ import math
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from vllm.config import get_current_vllm_config
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
from .common import rotate_neox
|
from .common import rotate_neox
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
|
class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
|
||||||
"""Phi3 family of models scaled rotary embedding.
|
"""Phi3 family of models scaled rotary embedding.
|
||||||
@ -43,6 +48,22 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
|
|||||||
self.short_factor = short_factor
|
self.short_factor = short_factor
|
||||||
self.long_factor = long_factor
|
self.long_factor = long_factor
|
||||||
|
|
||||||
|
# Force long factors if max_model_len (runtime max length) exceeds
|
||||||
|
# original_max_position_embeddings to prevent KV cache invalidation when
|
||||||
|
# sequences cross this threshold during generation
|
||||||
|
max_model_len = get_current_vllm_config().model_config.max_model_len
|
||||||
|
self.use_long_rope = max_model_len > original_max_position_embeddings
|
||||||
|
if self.use_long_rope:
|
||||||
|
logger.warning_once(
|
||||||
|
"Using LongRoPE scaling factors. This enables longer "
|
||||||
|
"contexts (%d tokens vs original %d tokens) at the cost of "
|
||||||
|
"some performance degradation for shorter sequences. If "
|
||||||
|
"this is not desired, set `max_model_len` to be at most %d.",
|
||||||
|
max_position_embeddings,
|
||||||
|
original_max_position_embeddings,
|
||||||
|
original_max_position_embeddings,
|
||||||
|
)
|
||||||
|
|
||||||
scale = self.max_position_embeddings / self.original_max_position_embeddings
|
scale = self.max_position_embeddings / self.original_max_position_embeddings
|
||||||
if scale <= 1.0:
|
if scale <= 1.0:
|
||||||
scaling_factor = 1.0
|
scaling_factor = 1.0
|
||||||
@ -112,15 +133,12 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
|
|||||||
query = query.view(*query.shape[:-1], -1, self.head_size)
|
query = query.view(*query.shape[:-1], -1, self.head_size)
|
||||||
key = key.view(*key.shape[:-1], -1, self.head_size)
|
key = key.view(*key.shape[:-1], -1, self.head_size)
|
||||||
|
|
||||||
k = self.original_max_position_embeddings
|
if self.use_long_rope:
|
||||||
long_prompt_offset = (
|
k = self.original_max_position_embeddings
|
||||||
torch.any(positions > k).float() * torch.full_like(positions, k)
|
long_prompt_offset = torch.full_like(positions, k).long()
|
||||||
).long()
|
idx = torch.add(positions, long_prompt_offset)
|
||||||
idx = (
|
else:
|
||||||
torch.add(positions, long_prompt_offset)
|
idx = positions
|
||||||
if long_prompt_offset is not None
|
|
||||||
else positions
|
|
||||||
)
|
|
||||||
idx = torch.add(idx, offsets) if offsets is not None else idx
|
idx = torch.add(idx, offsets) if offsets is not None else idx
|
||||||
cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)
|
cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user