mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 03:05:02 +08:00
[Model] Support for fairseq2 Llama (#11442)
Signed-off-by: Martin Gleize <mgleize@meta.com> Co-authored-by: mgleize user <mgleize@a100-st-p4de24xlarge-4.fair-a100.hpcaas>
This commit is contained in:
parent
81763c58a0
commit
bbe5f9de7d
@ -69,6 +69,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
"DeepseekV3ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3", # noqa: E501
|
||||
trust_remote_code=True),
|
||||
"ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"), # noqa: E501
|
||||
"Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501
|
||||
"FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
|
||||
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-2b"),
|
||||
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
|
||||
|
||||
@ -30,4 +30,5 @@ marlin, nm-testing/zephyr-beta-7b-marlin-g128, main
|
||||
marlin, robertgshaw2/zephyr-7b-beta-channelwise-marlin, main
|
||||
qqq, HandH1998/QQQ-Llama-3-8b-g128, main
|
||||
qqq, HandH1998/QQQ-Llama-3-8b, main
|
||||
hqq, nm-testing/Llama-3.2-1B-Instruct-HQQ, main
|
||||
hqq, nm-testing/Llama-3.2-1B-Instruct-HQQ, main
|
||||
None, mgleize/fairseq2-dummy-Llama-3.2-1B, main
|
||||
@ -20,12 +20,13 @@ def test_weight_loading(vllm_runner):
|
||||
"""
|
||||
Test parameter weight loading with tp>1.
|
||||
"""
|
||||
with vllm_runner(model_name=MODEL_NAME,
|
||||
revision=REVISION,
|
||||
dtype=torch.half if QUANTIZATION == "gptq" else "auto",
|
||||
quantization=QUANTIZATION,
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
tensor_parallel_size=2) as model:
|
||||
with vllm_runner(
|
||||
model_name=MODEL_NAME,
|
||||
revision=REVISION,
|
||||
dtype=torch.half if QUANTIZATION == "gptq" else "auto",
|
||||
quantization=None if QUANTIZATION == "None" else QUANTIZATION,
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
tensor_parallel_size=2) as model:
|
||||
|
||||
output = model.generate_greedy("Hello world!", max_tokens=20)
|
||||
print(output)
|
||||
|
||||
@ -344,11 +344,13 @@ class ColumnParallelLinear(LinearBase):
|
||||
param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
|
||||
|
||||
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
|
||||
is_sharded_weight = getattr(param, "is_sharded_weight", False)
|
||||
# bitsandbytes loads the weights of the specific portion
|
||||
# no need to narrow
|
||||
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
|
||||
|
||||
param_data = param.data
|
||||
# bitsandbytes loads the weights of the specific portion
|
||||
# no need to narrow here
|
||||
if output_dim is not None and not use_bitsandbytes_4bit:
|
||||
if output_dim is not None and not is_sharded_weight:
|
||||
shard_size = param_data.shape[output_dim]
|
||||
start_idx = tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
||||
@ -546,6 +548,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
|
||||
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
|
||||
False)
|
||||
is_sharded_weight = getattr(param, "is_sharded_weight", False)
|
||||
# bitsandbytes loads the weights of the specific portion
|
||||
# no need to narrow
|
||||
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
|
||||
|
||||
if use_bitsandbytes_4bit:
|
||||
shard_size = loaded_weight.shape[output_dim]
|
||||
shard_offset = loaded_weight.shape[output_dim] * \
|
||||
@ -554,9 +561,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
param_data = param_data.narrow(output_dim, shard_offset,
|
||||
shard_size)
|
||||
start_idx = tp_rank * shard_size
|
||||
# bitsandbytes loads the weights of the specific portion
|
||||
# no need to narrow here
|
||||
if not use_bitsandbytes_4bit:
|
||||
if not is_sharded_weight:
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
||||
shard_size)
|
||||
# Special case for AQLM codebooks.
|
||||
@ -941,6 +946,11 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
|
||||
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
|
||||
False)
|
||||
is_sharded_weight = getattr(param, "is_sharded_weight", False)
|
||||
# bitsandbytes loads the weights of the specific portion
|
||||
# no need to narrow
|
||||
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
|
||||
|
||||
if use_bitsandbytes_4bit:
|
||||
orig_qkv_offsets = {
|
||||
"q": (0, self.num_heads * self.head_size),
|
||||
@ -964,9 +974,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
shard_id = tp_rank // self.num_kv_head_replicas
|
||||
start_idx = shard_id * shard_size
|
||||
|
||||
# bitsandbytes loads the weights of the specific portion
|
||||
# no need to narrow here
|
||||
if not use_bitsandbytes_4bit:
|
||||
if not is_sharded_weight:
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
||||
shard_size)
|
||||
|
||||
@ -1070,6 +1078,10 @@ class RowParallelLinear(LinearBase):
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
input_dim = getattr(param, "input_dim", None)
|
||||
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
|
||||
is_sharded_weight = getattr(param, "is_sharded_weight", False)
|
||||
# bitsandbytes loads the weights of the specific portion
|
||||
# no need to narrow
|
||||
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
|
||||
|
||||
# Special case for GGUF
|
||||
is_gguf_weight = getattr(param, "is_gguf_weight", False)
|
||||
@ -1085,9 +1097,7 @@ class RowParallelLinear(LinearBase):
|
||||
param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)
|
||||
|
||||
param_data = param.data
|
||||
# bitsandbytes loads the weights of the specific portion
|
||||
# no need to narrow here
|
||||
if input_dim is not None and not use_bitsandbytes_4bit:
|
||||
if input_dim is not None and not is_sharded_weight:
|
||||
shard_size = param_data.shape[input_dim]
|
||||
start_idx = tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(input_dim, start_idx,
|
||||
|
||||
@ -182,6 +182,9 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
fall_back_to_pt: bool = True
|
||||
"""Whether .pt weights can be used."""
|
||||
|
||||
allow_patterns_overrides: Optional[list[str]] = None
|
||||
"""If defined, weights will load exclusively using these patterns."""
|
||||
|
||||
def __init__(self, load_config: LoadConfig):
|
||||
super().__init__(load_config)
|
||||
if load_config.model_loader_extra_config:
|
||||
@ -218,6 +221,7 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
model_name_or_path: str,
|
||||
revision: Optional[str],
|
||||
fall_back_to_pt: bool,
|
||||
allow_patterns_overrides: Optional[list[str]],
|
||||
) -> Tuple[str, List[str], bool]:
|
||||
"""Prepare weights for the model.
|
||||
|
||||
@ -249,6 +253,9 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
if fall_back_to_pt:
|
||||
allow_patterns += ["*.pt"]
|
||||
|
||||
if allow_patterns_overrides is not None:
|
||||
allow_patterns = allow_patterns_overrides
|
||||
|
||||
if not is_local:
|
||||
hf_folder = download_weights_from_hf(
|
||||
model_name_or_path,
|
||||
@ -298,7 +305,8 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
"""Get an iterator for the model weights based on the load format."""
|
||||
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
|
||||
source.model_or_path, source.revision, source.fall_back_to_pt)
|
||||
source.model_or_path, source.revision, source.fall_back_to_pt,
|
||||
source.allow_patterns_overrides)
|
||||
if self.load_config.load_format == LoadFormat.NPCACHE:
|
||||
# Currently np_cache only support *.bin checkpoints
|
||||
assert use_safetensors is False
|
||||
@ -340,6 +348,8 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
prefix="",
|
||||
fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load",
|
||||
True),
|
||||
allow_patterns_overrides=getattr(model, "allow_patterns_overrides",
|
||||
None),
|
||||
)
|
||||
yield from self._get_weights_iterator(primary_weights)
|
||||
|
||||
@ -353,7 +363,8 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
def download_model(self, model_config: ModelConfig) -> None:
|
||||
self._prepare_weights(model_config.model,
|
||||
model_config.revision,
|
||||
fall_back_to_pt=True)
|
||||
fall_back_to_pt=True,
|
||||
allow_patterns_overrides=None)
|
||||
|
||||
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
|
||||
device_config = vllm_config.device_config
|
||||
|
||||
151
vllm/model_executor/models/fairseq2_llama.py
Normal file
151
vllm/model_executor/models/fairseq2_llama.py
Normal file
@ -0,0 +1,151 @@
|
||||
# Copyright 2024 The vLLM team.
|
||||
# Copyright 2024 Meta Platforms, Inc. and affiliates. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Llama model for fairseq2 weights."""
|
||||
|
||||
from typing import Iterable, Set, Tuple
|
||||
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.layers.linear import set_weight_attrs
|
||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
|
||||
from .utils import AutoWeightsLoader, WeightsMapper
|
||||
|
||||
|
||||
class Fairseq2LlamaForCausalLM(LlamaForCausalLM):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
# For the model loader to read only the relevant checkpoint files
|
||||
self.allow_patterns_overrides = [
|
||||
# either the full checkpoint
|
||||
"model.pt",
|
||||
# or the tp-sharded checkpoint of the current rank
|
||||
f"model.{self.tp_rank}.pt",
|
||||
]
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
# fairseq2's serialization adds a wrapper to usual .pt state_dict's:
|
||||
# { "model_key": my_model_name, "my_model_name": state_dict }
|
||||
# which we first need to unpack
|
||||
weights_wrapped = dict(weights)
|
||||
weights = weights_wrapped[
|
||||
weights_wrapped["model_key"]].items() # type: ignore
|
||||
|
||||
# remap keys
|
||||
fs2_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
"decoder_frontend.embed.": "model.embed_tokens.",
|
||||
"decoder.": "model.",
|
||||
"final_proj.": "lm_head.",
|
||||
},
|
||||
orig_to_new_substr={
|
||||
".self_attn_layer_norm.": ".input_layernorm.",
|
||||
".ffn_layer_norm.": ".post_attention_layernorm.",
|
||||
".self_attn.output_proj.": ".self_attn.o_proj.",
|
||||
".ffn.gate_proj.": ".mlp.gate_proj.",
|
||||
".ffn.inner_proj.": ".mlp.up_proj.",
|
||||
".ffn.output_proj.": ".mlp.down_proj.",
|
||||
".layer_norm.": ".norm.",
|
||||
},
|
||||
)
|
||||
weights = fs2_to_vllm_mapper.apply(weights)
|
||||
|
||||
params = dict(self.named_parameters())
|
||||
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=(["lm_head."]
|
||||
if self.config.tie_word_embeddings else None),
|
||||
)
|
||||
return loader.load_weights(
|
||||
(self.reshape_fairseq2_weights(name, loaded_weight, params)
|
||||
for name, loaded_weight in weights))
|
||||
|
||||
def flag_sharded_weights(self, params: dict[str, Parameter]):
|
||||
"""Sets the `is_sharded_weight` flag to True for all sharded weights"""
|
||||
for name, param in params.items():
|
||||
modules = name.split(".")
|
||||
if "norm" in name and len(param.size()) < 2:
|
||||
# layer norms are not sharded
|
||||
continue
|
||||
elif any(emb in modules for emb in ["embed_tokens", "lm_head"]):
|
||||
# for now we repeat embedding layers for compatibility
|
||||
continue
|
||||
else:
|
||||
# all other layers are sharded
|
||||
set_weight_attrs(param, {"is_sharded_weight": True})
|
||||
|
||||
def reshape_fairseq2_weights(
|
||||
self,
|
||||
name: str,
|
||||
loaded_weight: torch.Tensor,
|
||||
params: dict[str, Parameter],
|
||||
) -> Tuple[str, torch.Tensor]:
|
||||
"""Reshape fairseq2's weights."""
|
||||
|
||||
def permute(w: torch.Tensor, n_heads: int) -> torch.Tensor:
|
||||
attn_in = self.config.head_dim * n_heads
|
||||
# check for a sharded weight on dim 0
|
||||
if attn_in // self.tp_size == w.size()[0]:
|
||||
attn_in //= self.tp_size
|
||||
n_heads //= self.tp_size
|
||||
attn_out = self.config.hidden_size
|
||||
return (w.view(n_heads, attn_in // n_heads // 2, 2,
|
||||
attn_out).transpose(1,
|
||||
2).reshape(attn_in, attn_out))
|
||||
|
||||
modules = name.split(".")
|
||||
|
||||
# rotary embeds should be sliced
|
||||
if "k_proj" in modules:
|
||||
loaded_weight = permute(loaded_weight,
|
||||
self.config.num_key_value_heads)
|
||||
|
||||
elif "q_proj" in modules:
|
||||
loaded_weight = permute(loaded_weight,
|
||||
self.config.num_attention_heads)
|
||||
|
||||
# We make the loaded weights compatible with both
|
||||
# full checkpoints and tp sharded checkpoints.
|
||||
# Embeddings are repeated to fit the vocab size.
|
||||
# Other weights are flagged for the weight_loader calls.
|
||||
if any(emb in modules for emb in ["embed_tokens", "lm_head"]):
|
||||
# Embeddings are sharded on dim 0
|
||||
dim = 0
|
||||
# In fairseq2, vocab size has to be divisible by tp_size
|
||||
# so we don't worry about padding
|
||||
if self.tp_size > 1 and loaded_weight.shape[
|
||||
dim] < self.config.vocab_size:
|
||||
assert loaded_weight.shape[
|
||||
dim] * self.tp_size == self.config.vocab_size, \
|
||||
"vocab_size should be divisible by tp_size."
|
||||
repeats = [1] * len(loaded_weight.size())
|
||||
repeats[dim] = self.tp_size
|
||||
# repeat to match vocab size and to be easily 'narrow'able
|
||||
loaded_weight = loaded_weight.repeat(repeats)
|
||||
set_weight_attrs(params[name], {"is_sharded_weight": False})
|
||||
# if embeddings are sharded, the rest is too
|
||||
if "embed_tokens" in modules:
|
||||
self.flag_sharded_weights(params)
|
||||
|
||||
return name, loaded_weight
|
||||
@ -47,6 +47,7 @@ _TEXT_GENERATION_MODELS = {
|
||||
"DeepseekV3ForCausalLM": ("deepseek_v3", "DeepseekV3ForCausalLM"),
|
||||
"ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
|
||||
"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
|
||||
"Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
|
||||
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
|
||||
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
|
||||
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user