[New Model] mBART model (#22883)

Signed-off-by: 汪志鹏 <wangzhipeng628@gmail.com>
汪志鹏 2025-08-16 20:16:58 +08:00 committed by GitHub
parent 4dff91c93d
commit 829bbd7882
6 changed files with 717 additions and 92 deletions


@@ -330,6 +330,7 @@ th {
| `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | ✅︎ |
| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | |
| `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | |
| `MBartForConditionalGeneration` | mBART | `facebook/mbart-large-en-ro`, `facebook/mbart-large-50`, etc. | | | |
| `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `zai-org/chatglm2-6b`, `zai-org/chatglm3-6b`, `ShieldLM-6B-chatglm3`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `CohereForCausalLM`, `Cohere2ForCausalLM` | Command-R | `CohereLabs/c4ai-command-r-v01`, `CohereLabs/c4ai-command-r7b-12-2024`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `DbrxForCausalLM` | DBRX | `databricks/dbrx-base`, `databricks/dbrx-instruct`, etc. | | ✅︎ | ✅︎ |
@@ -418,6 +419,9 @@ Some models are supported only via the [Transformers backend](#transformers). Th
!!! note
Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096.
!!! note
Some mBART checkpoints do not define `architectures` in their config files. In that case, pass `--hf-overrides '{"architectures": ["MBartForConditionalGeneration"]}'` to explicitly select the `MBartForConditionalGeneration` architecture.
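The same override can also be supplied programmatically through `hf_overrides`; a minimal sketch mirroring the example script added in this PR (the prompt text here is purely illustrative):

```python
from vllm import LLM, SamplingParams

# hf_overrides injects the missing `architectures` entry into the HF config.
llm = LLM(
    model="facebook/mbart-large-en-ro",
    hf_overrides={"architectures": ["MBartForConditionalGeneration"]},
)
outputs = llm.generate(
    {"encoder_prompt": "How are you today?", "decoder_prompt": ""},
    SamplingParams(temperature=0, max_tokens=30),
)
print(outputs[0].outputs[0].text)
```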
### Pooling Models
See [this page](./pooling_models.md) for more information on how to use pooling models.


@@ -2,9 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrate prompting of text-to-text
encoder/decoder models, specifically BART
encoder/decoder models, specifically BART and mBART.
This script is refactored to allow model selection via command-line arguments.
"""
import argparse
from typing import NamedTuple, Optional
from vllm import LLM, SamplingParams
from vllm.inputs import (
ExplicitEncoderDecoderPrompt,
@@ -14,119 +19,175 @@ from vllm.inputs import (
)
def create_prompts(tokenizer):
# Test prompts
#
# This section shows all of the valid ways to prompt an
# encoder/decoder model.
#
# - Helpers for building prompts
text_prompt_raw = "Hello, my name is"
text_prompt = TextPrompt(prompt="The president of the United States is")
class ModelRequestData(NamedTuple):
"""
Holds the configuration for a specific model, including its
HuggingFace ID and the prompts to use for the demo.
"""
model_id: str
encoder_prompts: list
decoder_prompts: list
hf_overrides: Optional[dict] = None
def get_bart_config() -> ModelRequestData:
"""
Returns the configuration for facebook/bart-large-cnn.
This uses the exact test cases from the original script.
"""
encoder_prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"An encoder prompt",
]
decoder_prompts = [
"A decoder prompt",
"Another decoder prompt",
]
return ModelRequestData(
model_id="facebook/bart-large-cnn",
encoder_prompts=encoder_prompts,
decoder_prompts=decoder_prompts,
)
def get_mbart_config() -> ModelRequestData:
"""
Returns the configuration for facebook/mbart-large-en-ro.
This uses prompts suitable for an English-to-Romanian translation task.
"""
encoder_prompts = [
"The quick brown fox jumps over the lazy dog.",
"How are you today?",
]
decoder_prompts = ["", ""]
hf_overrides = {"architectures": ["MBartForConditionalGeneration"]}
return ModelRequestData(
model_id="facebook/mbart-large-en-ro",
encoder_prompts=encoder_prompts,
decoder_prompts=decoder_prompts,
hf_overrides=hf_overrides,
)
MODEL_GETTERS = {
"bart": get_bart_config,
"mbart": get_mbart_config,
}
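# Adding another encoder/decoder model to this demo only requires a new
# get_<name>_config() getter plus an entry in MODEL_GETTERS; the rest of the
# script is model-agnostic.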
def create_all_prompt_types(
encoder_prompts_raw: list,
decoder_prompts_raw: list,
tokenizer,
) -> list:
"""
Generates a list of diverse prompt types for demonstration.
This function is generic and uses the provided raw prompts
to create various vLLM input objects.
"""
text_prompt_raw = encoder_prompts_raw[0]
text_prompt = TextPrompt(prompt=encoder_prompts_raw[1 % len(encoder_prompts_raw)])
tokens_prompt = TokensPrompt(
prompt_token_ids=tokenizer.encode(prompt="The capital of France is")
prompt_token_ids=tokenizer.encode(
encoder_prompts_raw[2 % len(encoder_prompts_raw)]
)
# - Pass a single prompt to encoder/decoder model
# (implicitly encoder input prompt);
# decoder input prompt is assumed to be None
single_text_prompt_raw = text_prompt_raw # Pass a string directly
single_text_prompt = text_prompt # Pass a TextPrompt
single_tokens_prompt = tokens_prompt # Pass a TokensPrompt
# ruff: noqa: E501
# - Pass explicit encoder and decoder input prompts within one data structure.
# Encoder and decoder prompts can both independently be text or tokens, with
# no requirement that they be the same prompt type. Some example prompt-type
# combinations are shown below, note that these are not exhaustive.
enc_dec_prompt1 = ExplicitEncoderDecoderPrompt(
# Pass encoder prompt string directly, &
# pass decoder prompt tokens
encoder_prompt=single_text_prompt_raw,
decoder_prompt=single_tokens_prompt,
)
enc_dec_prompt2 = ExplicitEncoderDecoderPrompt(
# Pass TextPrompt to encoder, and
# pass decoder prompt string directly
encoder_prompt=single_text_prompt,
decoder_prompt=single_text_prompt_raw,
)
enc_dec_prompt3 = ExplicitEncoderDecoderPrompt(
# Pass encoder prompt tokens directly, and
# pass TextPrompt to decoder
encoder_prompt=single_tokens_prompt,
decoder_prompt=single_text_prompt,
)
# - Finally, here's a useful helper function for zipping encoder and
# decoder prompts together into a list of ExplicitEncoderDecoderPrompt
# instances
decoder_tokens_prompt = TokensPrompt(
prompt_token_ids=tokenizer.encode(decoder_prompts_raw[0])
)
single_prompt_examples = [
text_prompt_raw,
text_prompt,
tokens_prompt,
]
explicit_pair_examples = [
ExplicitEncoderDecoderPrompt(
encoder_prompt=text_prompt_raw,
decoder_prompt=decoder_tokens_prompt,
),
ExplicitEncoderDecoderPrompt(
encoder_prompt=text_prompt,
decoder_prompt=decoder_prompts_raw[1 % len(decoder_prompts_raw)],
),
ExplicitEncoderDecoderPrompt(
encoder_prompt=tokens_prompt,
decoder_prompt=text_prompt,
),
]
zipped_prompt_list = zip_enc_dec_prompts(
["An encoder prompt", "Another encoder prompt"],
["A decoder prompt", "Another decoder prompt"],
encoder_prompts_raw,
decoder_prompts_raw,
)
# - Let's put all of the above example prompts together into one list
# which we will pass to the encoder/decoder LLM.
return [
single_text_prompt_raw,
single_text_prompt,
single_tokens_prompt,
enc_dec_prompt1,
enc_dec_prompt2,
enc_dec_prompt3,
] + zipped_prompt_list
return single_prompt_examples + explicit_pair_examples + zipped_prompt_list
# Create a sampling params object.
def create_sampling_params():
def create_sampling_params() -> SamplingParams:
"""Create a sampling params object."""
return SamplingParams(
temperature=0,
top_p=1.0,
min_tokens=0,
max_tokens=20,
max_tokens=30,
)
# Print the outputs.
def print_outputs(outputs):
print("-" * 50)
def print_outputs(outputs: list):
"""Formats and prints the generation outputs."""
print("-" * 80)
for i, output in enumerate(outputs):
prompt = output.prompt
encoder_prompt = output.encoder_prompt
generated_text = output.outputs[0].text
print(f"Output {i + 1}:")
print(
f"Encoder prompt: {encoder_prompt!r}\n"
f"Decoder prompt: {prompt!r}\n"
f"Generated text: {generated_text!r}"
print(f"Encoder Prompt: {encoder_prompt!r}")
print(f"Decoder Prompt: {prompt!r}")
print(f"Generated Text: {generated_text!r}")
print("-" * 80)
def main(args):
"""Main execution function."""
model_key = args.model
if model_key not in MODEL_GETTERS:
raise ValueError(
f"Unknown model: {model_key}. "
f"Available models: {list(MODEL_GETTERS.keys())}"
)
print("-" * 50)
config_getter = MODEL_GETTERS[model_key]
model_config = config_getter()
def main():
dtype = "float"
# Create a BART encoder/decoder model instance
print(f"🚀 Running demo for model: {model_config.model_id}")
llm = LLM(
model="facebook/bart-large-cnn",
dtype=dtype,
model=model_config.model_id,
dtype="float",
hf_overrides=model_config.hf_overrides,
)
# Get BART tokenizer
tokenizer = llm.llm_engine.get_tokenizer_group()
prompts = create_prompts(tokenizer)
prompts = create_all_prompt_types(
encoder_prompts_raw=model_config.encoder_prompts,
decoder_prompts_raw=model_config.decoder_prompts,
tokenizer=tokenizer,
)
sampling_params = create_sampling_params()
# Generate output tokens from the prompts. The output is a list of
# RequestOutput objects that contain the prompt, generated
# text, and other information.
outputs = llm.generate(prompts, sampling_params)
print_outputs(outputs)
if __name__ == "__main__":
main()
parser = argparse.ArgumentParser(
description="A flexible demo for vLLM encoder-decoder models."
)
parser.add_argument(
"--model",
"-m",
type=str,
default="bart",
choices=MODEL_GETTERS.keys(),
help="The short name of the model to run.",
)
args = parser.parse_args()
main(args)
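With this refactor the demo is invoked with `--model mbart` (or `-m mbart`) to exercise the new model path; `bart` remains the default choice.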


@@ -0,0 +1,123 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import pytest
from transformers import AutoModelForSeq2SeqLM
from vllm.sequence import SampleLogprobs
from ....conftest import DecoderPromptType, HfRunner, VllmRunner
from ...utils import check_logprobs_close
def vllm_to_hf_output(
vllm_output: tuple[list[int], str, Optional[SampleLogprobs]],
decoder_prompt_type: DecoderPromptType,
):
"""Sanitize vllm output to be comparable with hf output."""
output_ids, output_str, out_logprobs = vllm_output
hf_output_str = output_str + "</s>"
return output_ids, hf_output_str, out_logprobs
def run_test(
hf_runner: type[HfRunner],
vllm_runner: type[VllmRunner],
prompts: list[dict[str, str]],
decoder_prompt_type: DecoderPromptType,
model: str,
*,
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
) -> None:
'''
Test the vLLM mBART model by validating it against HuggingFace (HF).
(Docstring content is omitted for brevity)
'''
vllm_prompts = prompts
if decoder_prompt_type == DecoderPromptType.NONE:
vllm_prompts = [{
"encoder_prompt": p['encoder_prompt'],
"decoder_prompt": ""
} for p in prompts]
vllm_kwargs = {
"hf_overrides": {
"architectures": ["MBartForConditionalGeneration"]
}
}
with vllm_runner(model,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True,
**vllm_kwargs) as vllm_model: # type: ignore
vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
vllm_prompts, max_tokens, num_logprobs)
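# Neutralize HF generation defaults (beam search, sampling filters, length and
# repetition penalties) so the HF reference performs plain greedy decoding
# comparable to vLLM's output.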
hf_kwargs = {
"top_k": None,
"num_beams": 1,
"repetition_penalty": 1.0,
"top_p": 1.0,
"length_penalty": 1.0,
"early_stopping": False,
"no_repeat_ngram_size": None,
"min_length": 0
}
with hf_runner(model, dtype=dtype,
auto_cls=AutoModelForSeq2SeqLM) as hf_model:
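# Start HF decoding from the Romanian language token (ro_RO), matching the
# en->ro translation checkpoint under test.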
hf_kwargs["decoder_start_token_id"] = (
hf_model.tokenizer.lang_code_to_id["ro_RO"])
hf_outputs = (
hf_model.generate_encoder_decoder_greedy_logprobs_limit(
prompts, # HF runner still uses the original prompts
max_tokens,
num_logprobs,
**hf_kwargs,
))
hf_skip_tokens = 0
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=[
vllm_to_hf_output(vllm_output, decoder_prompt_type)
for vllm_output in vllm_outputs
],
name_0="hf",
name_1="vllm",
num_outputs_0_skip_tokens=hf_skip_tokens,
)
@pytest.mark.parametrize(
"model",
[pytest.param("facebook/mbart-large-en-ro")],
)
@pytest.mark.parametrize("dtype", ["float", "bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model,
dtype, max_tokens, num_logprobs, decoder_prompt_type) -> None:
run_test(
hf_runner,
vllm_runner,
example_encoder_decoder_prompts[decoder_prompt_type],
decoder_prompt_type,
model,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)


@@ -316,6 +316,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
# [Encoder-decoder]
"BartModel": _HfExamplesInfo("facebook/bart-base"),
"BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"),
"MBartForConditionalGeneration": _HfExamplesInfo("facebook/mbart-large-en-ro", # noqa: E501
hf_overrides={"architectures": ["MBartForConditionalGeneration"]}), # noqa: E501
}
_EMBEDDING_EXAMPLE_MODELS = {


@@ -46,7 +46,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsQuant, SupportsV0Only
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
maybe_prefix)
logger = logging.get_logger(__name__)
@@ -422,10 +423,7 @@ class BartEncoderLayer(nn.Module):
if hidden_states.dtype == torch.float16 and (
torch.isinf(hidden_states).any()
or torch.isnan(hidden_states).any()):
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states,
min=-clamp_value,
max=clamp_value)
hidden_states = cast_overflow_tensors(hidden_states)
return hidden_states
@@ -906,3 +904,439 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
})
return loaded_params
class MBartEncoderLayer(BartEncoderLayer):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
r"""
Args:
hidden_states
torch.Tensor of *encoder* input embeddings.
Returns:
Encoder layer output torch.Tensor
"""
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
fc1_out, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(fc1_out)
hidden_states, _ = self.fc2(hidden_states)
hidden_states = residual + hidden_states
if hidden_states.dtype == torch.float16 and (
torch.isinf(hidden_states).any()
or torch.isnan(hidden_states).any()):
hidden_states = cast_overflow_tensors(hidden_states)
return hidden_states
class MBartDecoderLayer(BartDecoderLayer):
def forward(
self,
decoder_hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
) -> torch.Tensor:
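# Same pre-LayerNorm ordering as the encoder layer: each sub-block normalizes
# its input before attention / MLP and then adds the residual.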
residual = decoder_hidden_states
hidden_states = self.self_attn_layer_norm(decoder_hidden_states)
# Self Attention
hidden_states = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states
# Cross-Attention Block
residual = hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
hidden_states = self.encoder_attn(
decoder_hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
fc1_out, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(fc1_out)
hidden_states, _ = self.fc2(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class MBartEncoder(nn.Module):
"""
Transformer encoder consisting of *config.encoder_layers*
self-attention layers. Each layer is an [`MBartEncoderLayer`].
Args:
config: BartConfig
embed_tokens (nn.Embedding): output embedding
"""
def __init__(self,
config: BartConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
embed_tokens: Optional[nn.Embedding] = None,
prefix: str = ""):
super().__init__()
self.cache_config = cache_config
self.quant_config = quant_config
self.lora_config = lora_config
embed_dim = config.d_model
self.max_source_positions = config.max_position_embeddings
embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
embed_dim,
embed_scale=embed_scale)
if embed_tokens is not None:
self.embed_tokens.weight = embed_tokens.weight
self.embed_positions = BartLearnedPositionalEmbedding(
config.max_position_embeddings,
embed_dim,
)
self.layers = nn.ModuleList([
MBartEncoderLayer(config,
cache_config,
quant_config,
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(config.encoder_layers)
])
self.layernorm_embedding = nn.LayerNorm(embed_dim)
self.layer_norm = nn.LayerNorm(config.d_model)  # mBART adds a final LayerNorm after the encoder stack (not present in BART)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
r"""
Args:
input_ids
Indices of *encoder* input sequence tokens in the vocabulary.
Padding will be ignored by default should you
provide it.
positions
Positions of *encoder* input sequence tokens.
Returns:
Encoder output torch.Tensor
"""
# retrieve input_ids and inputs_embeds
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
embed_pos = self.embed_positions(positions)
embed_pos = embed_pos.to(inputs_embeds.device)
hidden_states = inputs_embeds + embed_pos
hidden_states = self.layernorm_embedding(hidden_states)
for encoder_layer in self.layers:
hidden_states = encoder_layer(hidden_states=hidden_states)
hidden_states = self.layer_norm(hidden_states)
return hidden_states
class MBartDecoder(nn.Module):
"""
Transformer decoder consisting of *config.decoder_layers* layers.
Each layer is an [`MBartDecoderLayer`]
Args:
config: BartConfig
embed_tokens (nn.Embedding): output embedding
"""
def __init__(
self,
config: BartConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
embed_tokens: Optional[nn.Embedding] = None,
prefix: str = "",
):
super().__init__()
self.cache_config = cache_config
self.quant_config = quant_config
self.lora_config = lora_config
self.max_target_positions = config.max_position_embeddings
embed_scale = math.sqrt(
config.d_model) if config.scale_embedding else 1.0
self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
config.d_model,
embed_scale=embed_scale)
if embed_tokens is not None:
self.embed_tokens.weight = embed_tokens.weight
self.embed_positions = BartLearnedPositionalEmbedding(
config.max_position_embeddings,
config.d_model,
)
self.layers = nn.ModuleList(
[MBartDecoderLayer(config, cache_config, quant_config,
prefix=f"{prefix}.layers.{layer_idx}") \
for layer_idx in range(config.decoder_layers)])
self.layernorm_embedding = nn.LayerNorm(config.d_model)
self.layer_norm = nn.LayerNorm(config.d_model)
def forward(
self,
decoder_input_ids: torch.Tensor,
decoder_positions: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor],
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
r"""
Args:
decoder_input_ids
Indices of *decoder* input sequence tokens in the vocabulary.
Padding will be ignored by default should you
provide it.
decoder_positions
Positions of *decoder* input sequence tokens.
encoder_hidden_states:
Tensor of encoder output embeddings
Returns:
Decoder output torch.Tensor
"""
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(decoder_input_ids)
else:
decoder_positions = inputs_embeds[:, -1]
# embed positions
embed_pos = self.embed_positions(decoder_positions)
embed_pos = embed_pos.to(inputs_embeds.device)
hidden_states = inputs_embeds + embed_pos
hidden_states = self.layernorm_embedding(hidden_states)
# decoder layers
for decoder_layer in self.layers:
hidden_states = decoder_layer(
decoder_hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
)
hidden_states = self.layer_norm(hidden_states)
return hidden_states
class MBartModel(nn.Module, SupportsQuant):
_tied_weights_keys = [
"encoder.embed_tokens.weight", "decoder.embed_tokens.weight"
]
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.encoder = MBartEncoder(config,
cache_config,
quant_config=quant_config,
prefix=f"{prefix}.encoder")
self.decoder = MBartDecoder(config,
cache_config,
quant_config=quant_config,
prefix=f"{prefix}.decoder")
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
encoder_input_ids: torch.Tensor,
encoder_positions: torch.Tensor) -> torch.Tensor:
r"""
Args:
input_ids
Indices of *decoder* input sequence tokens in the vocabulary.
Padding will be ignored by default should you
provide it.
positions
Positions of *decoder* input sequence tokens.
encoder_input_ids
Indices of *encoder* input sequence tokens in the vocabulary.
encoder_positions:
Positions of *encoder* input sequence tokens.
Returns:
Model output torch.Tensor
"""
encoder_hidden_states = None
if encoder_input_ids.numel() > 0:
# Run encoder attention if a non-zero number of encoder tokens
# are provided as input
encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
positions=encoder_positions)
# decoder outputs consists of
# (dec_features, past_key_value, dec_hidden, dec_attn)
decoder_outputs = self.decoder(
decoder_input_ids=input_ids,
decoder_positions=positions,
encoder_hidden_states=encoder_hidden_states)
return decoder_outputs
class MBartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
base_model_prefix = "model"
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"decoder.": "model.decoder.",
"encoder.": "model.encoder.",
"shared.": "model.shared."
},
orig_to_new_substr={
"beta": "bias",
"gamma": "weight",
"LayerNorm": "layernorm",
},
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
lora_config = vllm_config.lora_config
assert config.tie_word_embeddings
self.config = config
self.model = MBartModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
embed_scale = math.sqrt(
config.d_model) if config.scale_embedding else 1.0
self.lm_head = BartParallelLMHead(config.vocab_size,
config.d_model,
embed_scale=embed_scale)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
*,
encoder_input_ids: torch.Tensor,
encoder_positions: torch.Tensor,
**kwargs,
) -> torch.Tensor:
return self.model(input_ids, positions, encoder_input_ids,
encoder_positions)
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
model_params_dict = dict(self.named_parameters())
loaded_params = set()
remaining_weights = []
shared_embedding_weight = None
for name, loaded_weight in weights:
if any(skip in name
for skip in ["cls.", "pooler.", "final_logits_bias"]):
continue
if any(embed_name in name for embed_name in [
'shared.weight', 'encoder.embed_tokens.weight',
'decoder.embed_tokens.weight'
]):
if shared_embedding_weight is None:
shared_embedding_weight = loaded_weight
continue
is_stacked = False
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
vllm_name = name
for src, dst in self.hf_to_vllm_mapper.orig_to_new_substr.items(
):
vllm_name = vllm_name.replace(src, dst)
for src, dst in self.hf_to_vllm_mapper.orig_to_new_prefix.items(
):
if vllm_name.startswith(src):
vllm_name = dst + vllm_name[len(src):]
break
vllm_name = vllm_name.replace(weight_name, param_name)
if vllm_name in model_params_dict:
param = model_params_dict[vllm_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight, shard_id)
loaded_params.add(vllm_name)
is_stacked = True
break
if not is_stacked:
remaining_weights.append((name, loaded_weight))
loader = AutoWeightsLoader(self, skip_prefixes=["cls.", "pooler."])
auto_loaded_params = loader.load_weights(remaining_weights,
mapper=self.hf_to_vllm_mapper)
loaded_params.update(auto_loaded_params)
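# Tie the shared embedding once: load it into lm_head and point both
# encoder.embed_tokens and decoder.embed_tokens at the same tensor,
# consistent with _tied_weights_keys and the tie_word_embeddings assert above.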
if shared_embedding_weight is not None:
lm_head_param = self.lm_head.weight
weight_loader = getattr(lm_head_param, "weight_loader",
default_weight_loader)
weight_loader(lm_head_param, shared_embedding_weight)
self.model.encoder.embed_tokens.weight = self.lm_head.weight
self.model.decoder.embed_tokens.weight = self.lm_head.weight
loaded_params.update({
'model.encoder.embed_tokens.weight', 'lm_head.weight',
'model.decoder.embed_tokens.weight'
})
return loaded_params


@@ -141,6 +141,7 @@ _TEXT_GENERATION_MODELS = {
# [Encoder-decoder]
"BartModel": ("bart", "BartForConditionalGeneration"),
"BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
"MBartForConditionalGeneration": ("bart", "MBartForConditionalGeneration"),
}
_EMBEDDING_MODELS = {