diff --git a/README.md b/README.md
index 947d50d4ad76..fce3de3b7043 100644
--- a/README.md
+++ b/README.md
@@ -78,6 +78,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
 - OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
 - Orion (`OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc.)
 - Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.)
+- Phi-3 (`microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, etc.)
 - Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
 - Qwen2 (`Qwen/Qwen1.5-7B`, `Qwen/Qwen1.5-7B-Chat`, etc.)
 - Qwen2MoE (`Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc.)
diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst
index 951fc3aac0c7..f4dd5d52ad87 100644
--- a/docs/source/models/supported_models.rst
+++ b/docs/source/models/supported_models.rst
@@ -115,6 +115,10 @@ Alongside each architecture, we include some popular models that use it.
     - Phi
     - :code:`microsoft/phi-1_5`, :code:`microsoft/phi-2`, etc.
     -
+  * - :code:`Phi3ForCausalLM`
+    - Phi-3
+    - :code:`microsoft/Phi-3-mini-4k-instruct`, :code:`microsoft/Phi-3-mini-128k-instruct`, etc.
+    -
   * - :code:`QWenLMHeadModel`
     - Qwen
     - :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc.
diff --git a/vllm/config.py b/vllm/config.py
index 2ff42de08f8f..311a69f82257 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1056,7 +1056,7 @@ def _get_and_verify_max_len(
         derived_max_model_len = default_max_len
 
     rope_scaling = getattr(hf_config, "rope_scaling", None)
-    if rope_scaling is not None:
+    if rope_scaling is not None and rope_scaling["type"] != "su":
         assert "factor" in rope_scaling
         scaling_factor = rope_scaling["factor"]
         if rope_scaling["type"] == "yarn":
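The config.py guard above is needed because "su"-type scaling, as shipped in Phi-3 checkpoints, carries no scalar "factor" key: it stores per-frequency factor lists, and the checkpoint's max_position_embeddings already reflects the extended window, so the derived max length must not be rescaled. A minimal sketch of the relevant shape of `hf_config.rope_scaling` (the factor values are illustrative placeholders, not taken from a real checkpoint; real configs carry one factor per rotary frequency, i.e. head_size // 2 entries):

    # Hypothetical, abridged "su" rope_scaling entry (Phi-3-mini-128k style).
    rope_scaling = {
        "type": "su",
        "short_factor": [1.0, 1.0, 1.0, 1.2],  # placeholder values
        "long_factor": [1.0, 1.1, 2.5, 4.8],   # placeholder values
    }
    # There is no scalar "factor", so _get_and_verify_max_len must skip the
    # generic rescaling branch and keep the derived max length as-is.
    assert "factor" not in rope_scaling

Note that `original_max_position_embeddings` lives at the top level of the Phi-3 config rather than inside `rope_scaling`; the llama.py hunk at the end of this patch copies it into the dict so `get_rope` can see it.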
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index a5225148d782..b8361af61ae3 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -338,6 +338,114 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         return cache
 
 
+class Phi3SuScaledRotaryEmbedding(nn.Module):
+    """Phi3 family of models scaled rotary embedding.
+
+    Based on the original RotaryEmbedding implementation.
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        original_max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        short_factor: List[float],
+        long_factor: List[float],
+        short_mscale: float = 1.1,
+        long_mscale: float = 1.225,
+    ):
+        super().__init__()
+
+        if rotary_dim != head_size:
+            raise ValueError(
+                "`Phi3SuScaledRotaryEmbedding` does not support "
+                f"rotary_dim != head_size ({rotary_dim} != {head_size}).")
+        if is_neox_style is False:
+            raise ValueError(
+                "`Phi3SuScaledRotaryEmbedding` only supports neox_style.")
+
+        self.head_size = head_size
+        self.max_position_embeddings = max_position_embeddings
+        self.original_max_position_embeddings = original_max_position_embeddings
+        self.base = base
+        self.short_factor = short_factor
+        self.long_factor = long_factor
+        self.short_mscale = short_mscale
+        self.long_mscale = long_mscale
+
+        short_cache = self._compute_cos_sin_cache(
+            original_max_position_embeddings, short_factor, short_mscale)
+        short_cache = short_cache.to(torch.get_default_dtype())
+        self.register_buffer("short_cos_sin_cache",
+                             short_cache,
+                             persistent=False)
+
+        long_cache = self._compute_cos_sin_cache(max_position_embeddings,
+                                                 long_factor, long_mscale)
+        long_cache = long_cache.to(torch.get_default_dtype())
+        self.register_buffer("long_cos_sin_cache",
+                             long_cache,
+                             persistent=False)
+
+        long_short_cache = torch.cat(
+            [self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0)
+        self.register_buffer("long_short_cos_sin_cache",
+                             long_short_cache,
+                             persistent=False)
+
+    def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor:
+        rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32)
+        inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange(
+            0, self.head_size, 2, dtype=torch.float) / self.head_size)))
+        return inv_freq
+
+    def _compute_cos_sin_cache(
+        self,
+        max_position_embeddings: int,
+        rescale_factors: List[float],
+        mscale: float,
+    ) -> torch.Tensor:
+        inv_freq = self._compute_inv_freq(rescale_factors)
+        t = torch.arange(max_position_embeddings, dtype=torch.float)
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = freqs.cos() * mscale
+        sin = freqs.sin() * mscale
+        cache = torch.cat((cos, sin), dim=-1)
+        return cache
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        query = query.view(*query.shape[:-1], -1, self.head_size)
+        key = key.view(*key.shape[:-1], -1, self.head_size)
+
+        k = self.original_max_position_embeddings
+        long_prompt_offset = (torch.any(positions > k).float() *
+                              torch.full_like(positions, k)).long()
+        idx = (torch.add(positions, long_prompt_offset)
+               if long_prompt_offset is not None else positions)
+        self.long_short_cos_sin_cache = self.long_short_cos_sin_cache.to(
+            idx.device)
+        idx = torch.add(idx, offsets) if offsets is not None else idx
+        cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)
+
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        cos = cos.repeat(1, 2).unsqueeze(-2)
+        sin = sin.repeat(1, 2).unsqueeze(-2)
+
+        query = query * cos + _rotate_neox(query) * sin
+        key = key * cos + _rotate_neox(key) * sin
+
+        return query.flatten(-2), key.flatten(-2)
+
+
 _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
 
 
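The `forward` above never branches per token to pick a cache. Instead, the short and long cos/sin caches are concatenated along dim 0 into `long_short_cos_sin_cache`, so "use the long cache" is just an index shift of `original_max_position_embeddings`. A standalone sketch of that selection trick (illustrative, not vLLM code):

    import torch

    # Stand-in: a 4-token window plays the role of
    # original_max_position_embeddings.
    original_max = 4
    positions = torch.tensor([0, 1, 2, 5])  # one position exceeds the window

    # Mirrors long_prompt_offset in forward(): if ANY position is past the
    # original window, shift every lookup into the long half of the cache.
    offset = (torch.any(positions > original_max).float() *
              torch.full_like(positions, original_max)).long()
    idx = positions + offset
    print(idx)  # tensor([4, 5, 6, 9]) -> rows in the long half

Because the gate is `torch.any` over the whole positions tensor, the switch is per batch rather than per token: a single position beyond the original window moves every lookup in that batch to the long cache.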
@@ -349,17 +457,26 @@ def get_rope(
     is_neox_style: bool = True,
     rope_scaling: Optional[Dict[str, Any]] = None,
 ) -> RotaryEmbedding:
+    if rope_scaling is not None:
+        # Transform list values into tuples so the cache key is hashable.
+        rope_scaling_tuple = {
+            k: tuple(v) if isinstance(v, list) else v
+            for k, v in rope_scaling.items()
+        }
+        rope_scaling_args = tuple(rope_scaling_tuple.items())
+    else:
+        rope_scaling_args = None
     key = (head_size, rotary_dim, max_position, base, is_neox_style,
-           tuple(rope_scaling.items()) if rope_scaling is not None else None)
+           rope_scaling_args)
     if key in _ROPE_DICT:
         return _ROPE_DICT[key]
-
     if rope_scaling is None:
         rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                      is_neox_style)
     else:
         scaling_type = rope_scaling["type"]
-        scaling_factor = rope_scaling["factor"]
+        if scaling_type != "su":
+            scaling_factor = rope_scaling["factor"]
         if scaling_type == "linear":
             rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                       max_position, base,
@@ -383,6 +500,19 @@ def get_rope(
                                                    base, is_neox_style,
                                                    scaling_factor,
                                                    **extra_kwargs)
+        elif scaling_type == "su":
+            short_factor = rope_scaling["short_factor"]
+            long_factor = rope_scaling["long_factor"]
+            original_max_position = rope_scaling[
+                "original_max_position_embeddings"]
+            extra_kwargs = {
+                k: v
+                for k, v in rope_scaling.items()
+                if k in ("short_mscale", "long_mscale")
+            }
+            rotary_emb = Phi3SuScaledRotaryEmbedding(
+                head_size, rotary_dim, max_position, original_max_position,
+                base, is_neox_style, short_factor, long_factor, **extra_kwargs)
         else:
             raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
     _ROPE_DICT[key] = rotary_emb
diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py
index 17fc97056804..c0aeab5dd303 100755
--- a/vllm/model_executor/models/__init__.py
+++ b/vllm/model_executor/models/__init__.py
@@ -46,6 +46,7 @@ _MODELS = {
     "OPTForCausalLM": ("opt", "OPTForCausalLM"),
     "OrionForCausalLM": ("orion", "OrionForCausalLM"),
     "PhiForCausalLM": ("phi", "PhiForCausalLM"),
+    "Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
     "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
     "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
     "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 016e3b039d1e..c102b40045c9 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -180,6 +180,10 @@ class LlamaDecoderLayer(nn.Module):
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
+        if rope_scaling is not None and getattr(
+                config, "original_max_position_embeddings", None):
+            rope_scaling["original_max_position_embeddings"] = (
+                config.original_max_position_embeddings)
         max_position_embeddings = getattr(config, "max_position_embeddings",
                                           8192)
         sliding_window = getattr(config, "sliding_window", None)
@@ -378,11 +382,11 @@ class LlamaForCausalLM(nn.Module):
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
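Since `Phi3ForCausalLM` is mapped onto vLLM's existing `LlamaForCausalLM` above, no new model file is needed and the usual entry point works unchanged. A hedged end-to-end smoke test (assumes this branch is installed; `trust_remote_code=True` may be required if the installed transformers release does not yet ship a native Phi-3 config class):

    from vllm import LLM, SamplingParams

    llm = LLM(model="microsoft/Phi-3-mini-4k-instruct", trust_remote_code=True)
    params = SamplingParams(temperature=0.0, max_tokens=64)
    outputs = llm.generate(["Explain rotary position embeddings briefly."],
                           params)
    print(outputs[0].outputs[0].text)

The 128k variant is the one that exercises the new `Phi3SuScaledRotaryEmbedding` path, as the 4k checkpoint ships without an "su" rope_scaling section.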