From 3302f0aef39c392321567ac1400101155e365a29 Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Wed, 20 Sep 2023 13:35:11 -0700
Subject: [PATCH] rope_theta and max_position_embeddings from config (#1096)

Co-authored-by: Woosuk Kwon
Co-authored-by: wnma3mz
---
 vllm/config.py                         | 70 ++++++++++++++------------
 vllm/model_executor/models/aquila.py   | 11 ++++
 vllm/model_executor/models/baichuan.py | 20 ++++++--
 vllm/model_executor/models/falcon.py   | 17 ++++---
 vllm/model_executor/models/gpt_j.py    | 16 ++++--
 vllm/model_executor/models/gpt_neox.py | 12 ++++-
 vllm/model_executor/models/internlm.py | 20 ++++++--
 vllm/model_executor/models/llama.py    | 19 ++++---
 vllm/model_executor/models/qwen.py     | 17 +++++--
 9 files changed, 140 insertions(+), 62 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index dd92fbccd8992..f3e204af1b591 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -57,7 +57,7 @@ class ModelConfig:
         load_format: str,
         dtype: str,
         seed: int,
-        revision: Optional[str],
+        revision: Optional[str] = None,
         max_model_len: Optional[int] = None,
         quantization: Optional[str] = None,
     ) -> None:
@@ -73,19 +73,11 @@ class ModelConfig:
 
         self.hf_config = get_config(model, trust_remote_code, revision)
         self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
+        self.max_model_len = _get_and_verify_max_len(self.hf_config,
+                                                     max_model_len)
         self._verify_load_format()
         self._verify_tokenizer_mode()
         self._verify_quantization()
-        self.max_model_len = None
-        if max_model_len is not None:
-            derived_max_model_len = self.get_max_model_len()
-            if max_model_len > derived_max_model_len:
-                logger.warning(
-                    f"User-specified max_model_len ({max_model_len}) is "
-                    f"greater than the derived max_model_len "
-                    f"({derived_max_model_len}). Make sure the value is "
-                    "correct and within the model context size.")
-            self.max_model_len = max_model_len
 
     def _verify_load_format(self) -> None:
         load_format = self.load_format.lower()
@@ -168,26 +160,7 @@ class ModelConfig:
         return total_num_attention_heads // parallel_config.tensor_parallel_size
 
     def get_max_model_len(self) -> int:
-        if self.max_model_len is not None:
-            return self.max_model_len
-        max_model_len = float("inf")
-        possible_keys = [
-            # OPT
-            "max_position_embeddings",
-            # GPT-2
-            "n_positions",
-            # MPT
-            "max_seq_len",
-            # Others
-            "max_sequence_length",
-            "max_seq_length",
-            "seq_len",
-        ]
-        for key in possible_keys:
-            max_len_key = getattr(self.hf_config, key, None)
-            if max_len_key is not None:
-                max_model_len = min(max_model_len, max_len_key)
-        return max_model_len
+        return self.max_model_len
 
     def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
         total_num_hidden_layers = self.hf_config.num_hidden_layers
@@ -348,3 +321,38 @@ def _get_and_verify_dtype(
                 f"of at least 8.0. Your {gpu_name} GPU has compute capability "
                 f"{compute_capability[0]}.{compute_capability[1]}.")
     return torch_dtype
+
+
+def _get_and_verify_max_len(
+    hf_config: PretrainedConfig,
+    max_model_len: Optional[int],
+) -> int:
+    """Get and verify the model's maximum length."""
+    derived_max_model_len = float("inf")
+    possible_keys = [
+        # OPT
+        "max_position_embeddings",
+        # GPT-2
+        "n_positions",
+        # MPT
+        "max_seq_len",
+        # Others
+        "max_sequence_length",
+        "max_seq_length",
+        "seq_len",
+    ]
+    for key in possible_keys:
+        max_len_key = getattr(hf_config, key, None)
+        if max_len_key is not None:
+            derived_max_model_len = min(derived_max_model_len, max_len_key)
+
+    if max_model_len is None:
+        max_model_len = derived_max_model_len
+    elif max_model_len > derived_max_model_len:
+        raise ValueError(
+            f"User-specified max_model_len ({max_model_len}) is greater than "
+            f"the derived max_model_len ({derived_max_model_len}, from the "
+            "model's config.json). This may lead to incorrect model "
+            "outputs or CUDA errors. Make sure the value is correct and "
+            "within the model context size.")
+    return max_model_len
diff --git a/vllm/model_executor/models/aquila.py b/vllm/model_executor/models/aquila.py
index 1551974112b2f..caf9b61ffa0e6 100644
--- a/vllm/model_executor/models/aquila.py
+++ b/vllm/model_executor/models/aquila.py
@@ -105,6 +105,8 @@ class AquilaAttention(nn.Module):
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
+        rope_theta: float = 10000,
+        max_position_embeddings: int = 8192,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -119,6 +121,8 @@ class AquilaAttention(nn.Module):
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
 
         self.qkv_proj = ColumnParallelLinear(
             hidden_size,
@@ -140,6 +144,8 @@ class AquilaAttention(nn.Module):
             self.head_dim,
             self.scaling,
             rotary_dim=self.head_dim,
+            base=self.rope_theta,
+            max_position=self.max_position_embeddings,
         )
 
     def forward(
@@ -164,10 +170,15 @@ class AquilaDecoderLayer(nn.Module):
     def __init__(self, config: AquilaConfig):
         super().__init__()
         self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        max_position_embeddings = getattr(config, "max_position_embeddings",
+                                          8192)
         self.self_attn = AquilaAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             num_kv_heads=config.num_attention_heads,
+            rope_theta=rope_theta,
+            max_position_embeddings=max_position_embeddings,
         )
         self.mlp = AquilaMLP(
             hidden_size=self.hidden_size,
diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py
index 17e971d7bb29f..277b2cc49b442 100644
--- a/vllm/model_executor/models/baichuan.py
+++ b/vllm/model_executor/models/baichuan.py
@@ -111,6 +111,8 @@ class BaiChuanAttention(nn.Module):
         hidden_size: int,
         num_heads: int,
         position_embedding: str,
+        rope_theta: float = 10000,
+        max_position_embeddings: int = 8192,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -122,6 +124,8 @@ class BaiChuanAttention(nn.Module):
             tensor_model_parallel_world_size)
         self.head_dim = hidden_size // self.total_num_heads
         self.postion_embedding = position_embedding
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
         # pylint: disable=invalid-name
         self.W_pack = ColumnParallelLinear(
             hidden_size,
@@ -151,10 +155,13 @@ class BaiChuanAttention(nn.Module):
                                                scaling, alibi_slopes)
         else:
             self.scaling = self.head_dim**-0.5
-            self.attn = PagedAttentionWithRoPE(self.num_heads,
-                                               self.head_dim,
-                                               self.scaling,
-                                               rotary_dim=self.head_dim)
+            self.attn = PagedAttentionWithRoPE(
+                self.num_heads,
+                self.head_dim,
+                self.scaling,
+                rotary_dim=self.head_dim,
+                base=self.rope_theta,
+                max_position=self.max_position_embeddings)
 
     def forward(
         self,
@@ -183,10 +190,15 @@ class BaiChuanDecoderLayer(nn.Module):
     def __init__(self, config: BaiChuanConfig, position_embedding: str):
         super().__init__()
         self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        max_position_embeddings = getattr(config, "max_position_embeddings",
+                                          8192)
         self.self_attn = BaiChuanAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             position_embedding=position_embedding,
+            rope_theta=rope_theta,
+            max_position_embeddings=max_position_embeddings,
         )
         self.mlp = BaiChuanMLP(
             hidden_size=self.hidden_size,
diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py
index dbd8a8203e4b4..e8e2171fe7552 100644
--- a/vllm/model_executor/models/falcon.py
+++ b/vllm/model_executor/models/falcon.py
@@ -161,12 +161,17 @@ class FalconAttention(nn.Module):
                 "Rotary and alibi are mutually exclusive.")
 
         if self.use_rotary:
-            # TODO(zhuohan): Pass in correct `max_position``
-            self.attn = PagedAttentionWithRoPE(self.num_heads,
-                                               self.head_dim,
-                                               self.inv_norm_factor,
-                                               rotary_dim=self.head_dim,
-                                               num_kv_heads=self.num_kv_heads)
+            rope_theta = getattr(config, "rope_theta", 10000)
+            max_position_embeddings = getattr(config,
+                                              "max_position_embeddings", 8192)
+            self.attn = PagedAttentionWithRoPE(
+                self.num_heads,
+                self.head_dim,
+                self.inv_norm_factor,
+                base=rope_theta,
+                max_position=max_position_embeddings,
+                rotary_dim=self.head_dim,
+                num_kv_heads=self.num_kv_heads)
         elif self.use_alibi:
             tp_rank = get_tensor_model_parallel_rank()
             head_start = tp_rank * self.num_heads
diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py
index c3e8da239acec..f8ffcdb7189a5 100644
--- a/vllm/model_executor/models/gpt_j.py
+++ b/vllm/model_executor/models/gpt_j.py
@@ -67,11 +67,17 @@ class GPTJAttention(nn.Module):
         scaling = self.head_size**-0.5
         assert getattr(config, "rotary", True)
         assert config.rotary_dim % 2 == 0
-        self.attn = PagedAttentionWithRoPE(self.num_heads,
-                                           self.head_size,
-                                           scaling,
-                                           config.rotary_dim,
-                                           is_neox_style=False)
+        rope_theta = getattr(config, "rope_theta", 10000)
+        max_position_embeddings = getattr(config, "max_position_embeddings",
+                                          8192)
+        self.attn = PagedAttentionWithRoPE(
+            self.num_heads,
+            self.head_size,
+            scaling,
+            config.rotary_dim,
+            base=rope_theta,
+            max_position=max_position_embeddings,
+            is_neox_style=False)
         self.warmup = False
 
     def forward(
diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py
index acbd1b47d6265..225726a630cf5 100644
--- a/vllm/model_executor/models/gpt_neox.py
+++ b/vllm/model_executor/models/gpt_neox.py
@@ -68,8 +68,16 @@ class GPTNeoXAttention(nn.Module):
         scaling = self.head_size**-0.5
         rotary_dim = int(self.head_size * config.rotary_pct)
         assert rotary_dim % 2 == 0
-        self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_size,
-                                           scaling, rotary_dim)
+        rope_theta = getattr(config, "rope_theta", 10000)
+        max_position_embeddings = getattr(config, "max_position_embeddings",
+                                          8192)
+        self.attn = PagedAttentionWithRoPE(
+            self.num_heads,
+            self.head_size,
+            scaling,
+            rotary_dim,
+            base=rope_theta,
+            max_position=max_position_embeddings)
 
     def forward(
         self,
diff --git a/vllm/model_executor/models/internlm.py b/vllm/model_executor/models/internlm.py
index fdcac02a1b276..55bd76be01409 100644
--- a/vllm/model_executor/models/internlm.py
+++ b/vllm/model_executor/models/internlm.py
@@ -59,6 +59,8 @@ class InternLMAttention(nn.Module):
         self,
         hidden_size: int,
         num_heads: int,
+        rope_theta: float = 10000,
+        max_position_embeddings: int = 8192,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -70,6 +72,8 @@ class InternLMAttention(nn.Module):
             tensor_model_parallel_world_size)
         self.head_dim = hidden_size // self.total_num_heads
         self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
 
         self.qkv_proj = ColumnParallelLinear(
             hidden_size,
@@ -85,10 +89,13 @@ class InternLMAttention(nn.Module):
             input_is_parallel=True,
             perform_initialization=False,
         )
-        self.attn = PagedAttentionWithRoPE(self.num_heads,
-                                           self.head_dim,
-                                           self.scaling,
-                                           rotary_dim=self.head_dim)
+        self.attn = PagedAttentionWithRoPE(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            base=self.rope_theta,
+            max_position=self.max_position_embeddings,
+            rotary_dim=self.head_dim)
 
     def forward(
         self,
@@ -112,9 +119,14 @@ class InternLMDecoderLayer(nn.Module):
     def __init__(self, config: LlamaConfig):
         super().__init__()
         self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        max_position_embeddings = getattr(config, "max_position_embeddings",
+                                          8192)
         self.self_attn = InternLMAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
+            rope_theta=rope_theta,
+            max_position_embeddings=max_position_embeddings,
         )
         self.mlp = InternLMMLP(
             hidden_size=self.hidden_size,
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 0b7f4181a1501..79b7bed267275 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -92,6 +92,7 @@ class LlamaAttention(nn.Module):
         num_heads: int,
         num_kv_heads: int,
         rope_theta: float = 10000,
+        max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -108,6 +109,7 @@ class LlamaAttention(nn.Module):
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
 
         self.qkv_proj = ParallelLinear.column(
             hidden_size,
@@ -126,12 +128,14 @@ class LlamaAttention(nn.Module):
             perform_initialization=False,
             quant_config=quant_config,
         )
-        self.attn = PagedAttentionWithRoPE(self.num_heads,
-                                           self.head_dim,
-                                           self.scaling,
-                                           base=self.rope_theta,
-                                           rotary_dim=self.head_dim,
-                                           num_kv_heads=self.num_kv_heads)
+        self.attn = PagedAttentionWithRoPE(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            base=self.rope_theta,
+            max_position=self.max_position_embeddings,
+            rotary_dim=self.head_dim,
+            num_kv_heads=self.num_kv_heads)
 
     def forward(
         self,
@@ -161,11 +165,14 @@ class LlamaDecoderLayer(nn.Module):
         self.hidden_size = config.hidden_size
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
+        max_position_embeddings = getattr(config, "max_position_embeddings",
+                                          8192)
         self.self_attn = LlamaAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
+            max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
         )
         self.mlp = LlamaMLP(
diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py
index 4ce9aea5e2c78..f572edb41db8a 100644
--- a/vllm/model_executor/models/qwen.py
+++ b/vllm/model_executor/models/qwen.py
@@ -76,8 +76,13 @@ class QWenMLP(nn.Module):
 
 class QWenAttention(nn.Module):
 
-    def __init__(self, hidden_size: int, num_heads: int,
-                 max_position_embeddings: int):
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        max_position_embeddings: int,
+        rope_theta: float = 10000,
+    ):
         super().__init__()
         self.hidden_size = hidden_size
         tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
@@ -109,6 +114,7 @@ class QWenAttention(nn.Module):
             self.head_dim,
             self.scaling,
             rotary_dim=self.head_dim,
+            base=rope_theta,
             max_position=max_position_embeddings,
         )
 
@@ -137,8 +143,11 @@ class QWenBlock(nn.Module):
         super().__init__()
         self.ln_1 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
 
-        self.attn = QWenAttention(config.n_embd, config.num_attention_heads,
-                                  config.max_position_embeddings)
+        rope_theta = getattr(config, "rope_theta", 10000)
+        self.attn = QWenAttention(config.n_embd,
+                                  config.num_attention_heads,
+                                  config.max_position_embeddings,
+                                  rope_theta=rope_theta)
 
         self.ln_2 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
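
The patch applies one pattern across all nine model files: read rope_theta and
max_position_embeddings from the Hugging Face config with getattr defaults
(10000 and 8192) so older checkpoints that lack these fields keep working, and
centralize the context-length derivation in _get_and_verify_max_len. Below is a
minimal standalone sketch of that derivation logic; it mirrors the patch but is
illustrative only (derive_max_len is a hypothetical helper, and SimpleNamespace
stands in for a real PretrainedConfig):

    from types import SimpleNamespace
    from typing import Optional

    def derive_max_len(hf_config, max_model_len: Optional[int]) -> int:
        # Same idea as the patch's _get_and_verify_max_len: the effective
        # context length is the smallest limit any known config key declares.
        derived = float("inf")
        for key in ("max_position_embeddings", "n_positions", "max_seq_len",
                    "max_sequence_length", "max_seq_length", "seq_len"):
            value = getattr(hf_config, key, None)
            if value is not None:
                derived = min(derived, value)
        if max_model_len is None:
            return derived
        if max_model_len > derived:
            raise ValueError(f"max_model_len ({max_model_len}) exceeds the "
                             f"derived limit ({derived})")
        return max_model_len

    # LLaMA-2-style config: both fields present in config.json.
    llama_cfg = SimpleNamespace(max_position_embeddings=4096, rope_theta=10000.0)
    assert derive_max_len(llama_cfg, None) == 4096
    assert derive_max_len(llama_cfg, 2048) == 2048

    # Older GPT-2-style config with neither RoPE field: rope_theta falls back
    # to 10000 via getattr, exactly as in the per-model __init__ changes above.
    old_cfg = SimpleNamespace(n_positions=2048)
    assert getattr(old_cfg, "rope_theta", 10000) == 10000
    assert derive_max_len(old_cfg, None) == 2048

Raising instead of warning when the user-specified max_model_len exceeds the
derived limit is the behavioral change of this patch: an oversized value would
otherwise index RoPE caches past max_position and corrupt outputs.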