From 829bbd7882222c85c0ca5a17fbb2f70e543f50ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B1=AA=E5=BF=97=E9=B9=8F?= Date: Sat, 16 Aug 2025 20:16:58 +0800 Subject: [PATCH] [New Model]mBART model (#22883) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 汪志鹏 --- docs/models/supported_models.md | 4 + examples/offline_inference/encoder_decoder.py | 235 +++++---- .../models/language/generation/test_mbart.py | 123 +++++ tests/models/registry.py | 2 + vllm/model_executor/models/bart.py | 444 +++++++++++++++++- vllm/model_executor/models/registry.py | 1 + 6 files changed, 717 insertions(+), 92 deletions(-) create mode 100644 tests/models/language/generation/test_mbart.py diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index a24fa4bcce33..a514572945c3 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -330,6 +330,7 @@ th { | `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | ✅︎ | | `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | | | `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | | +| `MBartForConditionalGeneration` | mBART | `facebook/mbart-large-en-ro`, `facebook/mbart-large-50`, etc. | | | | | `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `zai-org/chatglm2-6b`, `zai-org/chatglm3-6b`, `ShieldLM-6B-chatglm3`, etc. | ✅︎ | ✅︎ | ✅︎ | | `CohereForCausalLM`, `Cohere2ForCausalLM` | Command-R | `CohereLabs/c4ai-command-r-v01`, `CohereLabs/c4ai-command-r7b-12-2024`, etc. | ✅︎ | ✅︎ | ✅︎ | | `DbrxForCausalLM` | DBRX | `databricks/dbrx-base`, `databricks/dbrx-instruct`, etc. | | ✅︎ | ✅︎ | @@ -418,6 +419,9 @@ Some models are supported only via the [Transformers backend](#transformers). Th !!! note Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096. +!!! note + Some mBART models' config files do not have an `architecture` defined. Therefore, you need to use `--hf-overrides '{"architectures": ["MBartForConditionalGeneration"]}'` to explicitly specify the use of the `MBartForConditionalGeneration` architecture. + ### Pooling Models See [this page](./pooling_models.md) for more information on how to use pooling models. diff --git a/examples/offline_inference/encoder_decoder.py b/examples/offline_inference/encoder_decoder.py index 0da6fa5c4af5..df6c1eaf4a21 100644 --- a/examples/offline_inference/encoder_decoder.py +++ b/examples/offline_inference/encoder_decoder.py @@ -2,9 +2,14 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ Demonstrate prompting of text-to-text -encoder/decoder models, specifically BART +encoder/decoder models, specifically BART and mBART. + +This script is refactored to allow model selection via command-line arguments. """ +import argparse +from typing import NamedTuple, Optional + from vllm import LLM, SamplingParams from vllm.inputs import ( ExplicitEncoderDecoderPrompt, @@ -14,119 +19,175 @@ from vllm.inputs import ( ) -def create_prompts(tokenizer): - # Test prompts - # - # This section shows all of the valid ways to prompt an - # encoder/decoder model. 
- # - # - Helpers for building prompts - text_prompt_raw = "Hello, my name is" - text_prompt = TextPrompt(prompt="The president of the United States is") +class ModelRequestData(NamedTuple): + """ + Holds the configuration for a specific model, including its + HuggingFace ID and the prompts to use for the demo. + """ + + model_id: str + encoder_prompts: list + decoder_prompts: list + hf_overrides: Optional[dict] = None + + +def get_bart_config() -> ModelRequestData: + """ + Returns the configuration for facebook/bart-large-cnn. + This uses the exact test cases from the original script. + """ + encoder_prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "An encoder prompt", + ] + decoder_prompts = [ + "A decoder prompt", + "Another decoder prompt", + ] + return ModelRequestData( + model_id="facebook/bart-large-cnn", + encoder_prompts=encoder_prompts, + decoder_prompts=decoder_prompts, + ) + + +def get_mbart_config() -> ModelRequestData: + """ + Returns the configuration for facebook/mbart-large-en-ro. + This uses prompts suitable for an English-to-Romanian translation task. + """ + encoder_prompts = [ + "The quick brown fox jumps over the lazy dog.", + "How are you today?", + ] + decoder_prompts = ["", ""] + hf_overrides = {"architectures": ["MBartForConditionalGeneration"]} + return ModelRequestData( + model_id="facebook/mbart-large-en-ro", + encoder_prompts=encoder_prompts, + decoder_prompts=decoder_prompts, + hf_overrides=hf_overrides, + ) + + +MODEL_GETTERS = { + "bart": get_bart_config, + "mbart": get_mbart_config, +} + + +def create_all_prompt_types( + encoder_prompts_raw: list, + decoder_prompts_raw: list, + tokenizer, +) -> list: + """ + Generates a list of diverse prompt types for demonstration. + This function is generic and uses the provided raw prompts + to create various vLLM input objects. + """ + text_prompt_raw = encoder_prompts_raw[0] + text_prompt = TextPrompt(prompt=encoder_prompts_raw[1 % len(encoder_prompts_raw)]) tokens_prompt = TokensPrompt( - prompt_token_ids=tokenizer.encode(prompt="The capital of France is") - ) - # - Pass a single prompt to encoder/decoder model - # (implicitly encoder input prompt); - # decoder input prompt is assumed to be None - - single_text_prompt_raw = text_prompt_raw # Pass a string directly - single_text_prompt = text_prompt # Pass a TextPrompt - single_tokens_prompt = tokens_prompt # Pass a TokensPrompt - - # ruff: noqa: E501 - # - Pass explicit encoder and decoder input prompts within one data structure. - # Encoder and decoder prompts can both independently be text or tokens, with - # no requirement that they be the same prompt type. Some example prompt-type - # combinations are shown below, note that these are not exhaustive. 
- - enc_dec_prompt1 = ExplicitEncoderDecoderPrompt( - # Pass encoder prompt string directly, & - # pass decoder prompt tokens - encoder_prompt=single_text_prompt_raw, - decoder_prompt=single_tokens_prompt, - ) - enc_dec_prompt2 = ExplicitEncoderDecoderPrompt( - # Pass TextPrompt to encoder, and - # pass decoder prompt string directly - encoder_prompt=single_text_prompt, - decoder_prompt=single_text_prompt_raw, - ) - enc_dec_prompt3 = ExplicitEncoderDecoderPrompt( - # Pass encoder prompt tokens directly, and - # pass TextPrompt to decoder - encoder_prompt=single_tokens_prompt, - decoder_prompt=single_text_prompt, + prompt_token_ids=tokenizer.encode( + encoder_prompts_raw[2 % len(encoder_prompts_raw)] + ) ) - # - Finally, here's a useful helper function for zipping encoder and - # decoder prompts together into a list of ExplicitEncoderDecoderPrompt - # instances + decoder_tokens_prompt = TokensPrompt( + prompt_token_ids=tokenizer.encode(decoder_prompts_raw[0]) + ) + single_prompt_examples = [ + text_prompt_raw, + text_prompt, + tokens_prompt, + ] + explicit_pair_examples = [ + ExplicitEncoderDecoderPrompt( + encoder_prompt=text_prompt_raw, + decoder_prompt=decoder_tokens_prompt, + ), + ExplicitEncoderDecoderPrompt( + encoder_prompt=text_prompt, + decoder_prompt=decoder_prompts_raw[1 % len(decoder_prompts_raw)], + ), + ExplicitEncoderDecoderPrompt( + encoder_prompt=tokens_prompt, + decoder_prompt=text_prompt, + ), + ] zipped_prompt_list = zip_enc_dec_prompts( - ["An encoder prompt", "Another encoder prompt"], - ["A decoder prompt", "Another decoder prompt"], + encoder_prompts_raw, + decoder_prompts_raw, ) - - # - Let's put all of the above example prompts together into one list - # which we will pass to the encoder/decoder LLM. - return [ - single_text_prompt_raw, - single_text_prompt, - single_tokens_prompt, - enc_dec_prompt1, - enc_dec_prompt2, - enc_dec_prompt3, - ] + zipped_prompt_list + return single_prompt_examples + explicit_pair_examples + zipped_prompt_list -# Create a sampling params object. -def create_sampling_params(): +def create_sampling_params() -> SamplingParams: + """Create a sampling params object.""" return SamplingParams( temperature=0, top_p=1.0, min_tokens=0, - max_tokens=20, + max_tokens=30, ) -# Print the outputs. -def print_outputs(outputs): - print("-" * 50) +def print_outputs(outputs: list): + """Formats and prints the generation outputs.""" + print("-" * 80) for i, output in enumerate(outputs): prompt = output.prompt encoder_prompt = output.encoder_prompt generated_text = output.outputs[0].text print(f"Output {i + 1}:") - print( - f"Encoder prompt: {encoder_prompt!r}\n" - f"Decoder prompt: {prompt!r}\n" - f"Generated text: {generated_text!r}" + print(f"Encoder Prompt: {encoder_prompt!r}") + print(f"Decoder Prompt: {prompt!r}") + print(f"Generated Text: {generated_text!r}") + print("-" * 80) + + +def main(args): + """Main execution function.""" + model_key = args.model + if model_key not in MODEL_GETTERS: + raise ValueError( + f"Unknown model: {model_key}. 
" + f"Available models: {list(MODEL_GETTERS.keys())}" ) - print("-" * 50) + config_getter = MODEL_GETTERS[model_key] + model_config = config_getter() - -def main(): - dtype = "float" - - # Create a BART encoder/decoder model instance + print(f"🚀 Running demo for model: {model_config.model_id}") llm = LLM( - model="facebook/bart-large-cnn", - dtype=dtype, + model=model_config.model_id, + dtype="float", + hf_overrides=model_config.hf_overrides, ) - - # Get BART tokenizer tokenizer = llm.llm_engine.get_tokenizer_group() - - prompts = create_prompts(tokenizer) + prompts = create_all_prompt_types( + encoder_prompts_raw=model_config.encoder_prompts, + decoder_prompts_raw=model_config.decoder_prompts, + tokenizer=tokenizer, + ) sampling_params = create_sampling_params() - - # Generate output tokens from the prompts. The output is a list of - # RequestOutput objects that contain the prompt, generated - # text, and other information. outputs = llm.generate(prompts, sampling_params) - print_outputs(outputs) if __name__ == "__main__": - main() + parser = argparse.ArgumentParser( + description="A flexible demo for vLLM encoder-decoder models." + ) + parser.add_argument( + "--model", + "-m", + type=str, + default="bart", + choices=MODEL_GETTERS.keys(), + help="The short name of the model to run.", + ) + args = parser.parse_args() + main(args) diff --git a/tests/models/language/generation/test_mbart.py b/tests/models/language/generation/test_mbart.py new file mode 100644 index 000000000000..854a72713943 --- /dev/null +++ b/tests/models/language/generation/test_mbart.py @@ -0,0 +1,123 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import pytest +from transformers import AutoModelForSeq2SeqLM + +from vllm.sequence import SampleLogprobs + +from ....conftest import DecoderPromptType, HfRunner, VllmRunner +from ...utils import check_logprobs_close + + +def vllm_to_hf_output( + vllm_output: tuple[list[int], str, Optional[SampleLogprobs]], + decoder_prompt_type: DecoderPromptType, +): + """Sanitize vllm output to be comparable with hf output.""" + output_ids, output_str, out_logprobs = vllm_output + hf_output_str = output_str + "" + return output_ids, hf_output_str, out_logprobs + + +def run_test( + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + prompts: list[dict[str, str]], + decoder_prompt_type: DecoderPromptType, + model: str, + *, + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +) -> None: + ''' + Test the vLLM mBART model by validating it against HuggingFace (HF). 
+ (Docstring content is omitted for brevity) + ''' + + vllm_prompts = prompts + if decoder_prompt_type == DecoderPromptType.NONE: + vllm_prompts = [{ + "encoder_prompt": p['encoder_prompt'], + "decoder_prompt": "" + } for p in prompts] + + vllm_kwargs = { + "hf_overrides": { + "architectures": ["MBartForConditionalGeneration"] + } + } + + with vllm_runner(model, + dtype=dtype, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True, + **vllm_kwargs) as vllm_model: # type: ignore + vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( + vllm_prompts, max_tokens, num_logprobs) + + hf_kwargs = { + "top_k": None, + "num_beams": 1, + "repetition_penalty": 1.0, + "top_p": 1.0, + "length_penalty": 1.0, + "early_stopping": False, + "no_repeat_ngram_size": None, + "min_length": 0 + } + + with hf_runner(model, dtype=dtype, + auto_cls=AutoModelForSeq2SeqLM) as hf_model: + hf_kwargs["decoder_start_token_id"] = ( + hf_model.tokenizer.lang_code_to_id["ro_RO"]) + + hf_outputs = ( + hf_model.generate_encoder_decoder_greedy_logprobs_limit( + prompts, # HF runner still uses the original prompts + max_tokens, + num_logprobs, + **hf_kwargs, + )) + + hf_skip_tokens = 0 + + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=[ + vllm_to_hf_output(vllm_output, decoder_prompt_type) + for vllm_output in vllm_outputs + ], + name_0="hf", + name_1="vllm", + num_outputs_0_skip_tokens=hf_skip_tokens, + ) + + +@pytest.mark.parametrize( + "model", + [pytest.param("facebook/mbart-large-en-ro")], +) +@pytest.mark.parametrize("dtype", ["float", "bfloat16"]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) +def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model, + dtype, max_tokens, num_logprobs, decoder_prompt_type) -> None: + + run_test( + hf_runner, + vllm_runner, + example_encoder_decoder_prompts[decoder_prompt_type], + decoder_prompt_type, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) diff --git a/tests/models/registry.py b/tests/models/registry.py index 10e29e01e8a1..99cf997790fe 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -316,6 +316,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { # [Encoder-decoder] "BartModel": _HfExamplesInfo("facebook/bart-base"), "BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"), + "MBartForConditionalGeneration": _HfExamplesInfo("facebook/mbart-large-en-ro", # noqa: E501 + hf_overrides={"architectures": ["MBartForConditionalGeneration"]}), # noqa: E501 } _EMBEDDING_EXAMPLE_MODELS = { diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 3d328c88ff6e..32551d8102f3 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -46,7 +46,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsQuant, SupportsV0Only -from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix +from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors, + maybe_prefix) logger = logging.get_logger(__name__) @@ -422,10 +423,7 @@ class BartEncoderLayer(nn.Module): if hidden_states.dtype == torch.float16 and ( torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()): - clamp_value = 
torch.finfo(hidden_states.dtype).max - 1000
-            hidden_states = torch.clamp(hidden_states,
-                                        min=-clamp_value,
-                                        max=clamp_value)
+            hidden_states = cast_overflow_tensors(hidden_states)

         return hidden_states

@@ -906,3 +904,439 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
             })

         return loaded_params
+
+
+class MBartEncoderLayer(BartEncoderLayer):
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        r"""
+        Args:
+            hidden_states
+                torch.Tensor of *encoder* input embeddings.
+        Returns:
+            Encoder layer output torch.Tensor
+        """
+        residual = hidden_states
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+        hidden_states = self.self_attn(hidden_states=hidden_states)
+
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+        fc1_out, _ = self.fc1(hidden_states)
+        hidden_states = self.activation_fn(fc1_out)
+
+        hidden_states, _ = self.fc2(hidden_states)
+
+        hidden_states = residual + hidden_states
+
+        if hidden_states.dtype == torch.float16 and (
+                torch.isinf(hidden_states).any()
+                or torch.isnan(hidden_states).any()):
+            hidden_states = cast_overflow_tensors(hidden_states)
+
+        return hidden_states
+
+
+class MBartDecoderLayer(BartDecoderLayer):
+
+    def forward(
+        self,
+        decoder_hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        residual = decoder_hidden_states
+        hidden_states = self.self_attn_layer_norm(decoder_hidden_states)
+
+        # Self Attention
+        hidden_states = self.self_attn(hidden_states=hidden_states)
+
+        hidden_states = residual + hidden_states
+
+        # Cross-Attention Block
+
+        residual = hidden_states
+        hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+        hidden_states = self.encoder_attn(
+            decoder_hidden_states=hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+        )
+
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+        fc1_out, _ = self.fc1(hidden_states)
+        hidden_states = self.activation_fn(fc1_out)
+
+        hidden_states, _ = self.fc2(hidden_states)
+
+        hidden_states = residual + hidden_states
+
+        return hidden_states
+
+
+class MBartEncoder(nn.Module):
+    """
+    Transformer encoder consisting of *config.encoder_layers*
+    self attention layers. Each layer is a [`MBartEncoderLayer`].
+    Args:
+        config: BartConfig
+        embed_tokens (nn.Embedding): output embedding
+    """
+
+    def __init__(self,
+                 config: BartConfig,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 lora_config: Optional[LoRAConfig] = None,
+                 embed_tokens: Optional[nn.Embedding] = None,
+                 prefix: str = ""):
+        super().__init__()
+
+        self.cache_config = cache_config
+        self.quant_config = quant_config
+        self.lora_config = lora_config
+        embed_dim = config.d_model
+        self.max_source_positions = config.max_position_embeddings
+        embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
+
+        self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
+                                                    embed_dim,
+                                                    embed_scale=embed_scale)
+
+        if embed_tokens is not None:
+            self.embed_tokens.weight = embed_tokens.weight
+
+        self.embed_positions = BartLearnedPositionalEmbedding(
+            config.max_position_embeddings,
+            embed_dim,
+        )
+        self.layers = nn.ModuleList([
+            MBartEncoderLayer(config,
+                              cache_config,
+                              quant_config,
+                              prefix=f"{prefix}.layers.{layer_idx}")
+            for layer_idx in range(config.encoder_layers)
+        ])
+
+        self.layernorm_embedding = nn.LayerNorm(embed_dim)
+        # mBART (unlike BART) applies a final LayerNorm after the last
+        # encoder layer.
+        self.layer_norm = nn.LayerNorm(config.d_model)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        r"""
+        Args:
+            input_ids
+                Indices of *encoder* input sequence tokens in the vocabulary.
+                Padding will be ignored by default should you
+                provide it.
+            positions
+                Positions of *encoder* input sequence tokens.
+        Returns:
+            Encoder output torch.Tensor
+        """
+        # retrieve input_ids and inputs_embeds
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        embed_pos = self.embed_positions(positions)
+        embed_pos = embed_pos.to(inputs_embeds.device)
+
+        hidden_states = inputs_embeds + embed_pos
+        hidden_states = self.layernorm_embedding(hidden_states)
+
+        for encoder_layer in self.layers:
+            hidden_states = encoder_layer(hidden_states=hidden_states)
+
+        hidden_states = self.layer_norm(hidden_states)
+        return hidden_states
+
+
+class MBartDecoder(nn.Module):
+    """
+    Transformer decoder consisting of *config.decoder_layers* layers.
+    Each layer is a [`MBartDecoderLayer`]
+    Args:
+        config: BartConfig
+        embed_tokens (nn.Embedding): output embedding
+    """
+
+    def __init__(
+        self,
+        config: BartConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
+        embed_tokens: Optional[nn.Embedding] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.cache_config = cache_config
+        self.quant_config = quant_config
+        self.lora_config = lora_config
+        self.max_target_positions = config.max_position_embeddings
+        embed_scale = math.sqrt(
+            config.d_model) if config.scale_embedding else 1.0
+
+        self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
+                                                    config.d_model,
+                                                    embed_scale=embed_scale)
+
+        if embed_tokens is not None:
+            self.embed_tokens.weight = embed_tokens.weight
+
+        self.embed_positions = BartLearnedPositionalEmbedding(
+            config.max_position_embeddings,
+            config.d_model,
+        )
+
+        self.layers = nn.ModuleList(
+            [MBartDecoderLayer(config, cache_config, quant_config,
+                               prefix=f"{prefix}.layers.{layer_idx}") \
+             for layer_idx in range(config.decoder_layers)])
+
+        self.layernorm_embedding = nn.LayerNorm(config.d_model)
+        self.layer_norm = nn.LayerNorm(config.d_model)
+
+    def forward(
+        self,
+        decoder_input_ids: torch.Tensor,
+        decoder_positions: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor],
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        r"""
+        Args:
+            decoder_input_ids
+                Indices of *decoder* input sequence tokens in the vocabulary.
+                Padding will be ignored by default should you
+                provide it.
+            decoder_positions
+                Positions of *decoder* input sequence tokens.
+            encoder_hidden_states:
+                Tensor of encoder output embeddings
+        Returns:
+            Decoder output torch.Tensor
+        """
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(decoder_input_ids)
+        else:
+            decoder_positions = inputs_embeds[:, -1]
+
+        # embed positions
+        embed_pos = self.embed_positions(decoder_positions)
+        embed_pos = embed_pos.to(inputs_embeds.device)
+
+        hidden_states = inputs_embeds + embed_pos
+        hidden_states = self.layernorm_embedding(hidden_states)
+
+        # decoder layers
+
+        for decoder_layer in self.layers:
+            hidden_states = decoder_layer(
+                decoder_hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+            )
+
+        hidden_states = self.layer_norm(hidden_states)
+        return hidden_states
+
+
+class MBartModel(nn.Module, SupportsQuant):
+    _tied_weights_keys = [
+        "encoder.embed_tokens.weight", "decoder.embed_tokens.weight"
+    ]
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+        lora_config = vllm_config.lora_config
+
+        self.config = config
+
+        lora_vocab = (lora_config.lora_extra_vocab_size *
+                      (lora_config.max_loras or 1)) if lora_config else 0
+        self.vocab_size = config.vocab_size + lora_vocab
+        self.org_vocab_size = config.vocab_size
+
+        self.encoder = MBartEncoder(config,
+                                    cache_config,
+                                    quant_config=quant_config,
+                                    prefix=f"{prefix}.encoder")
+        self.decoder = MBartDecoder(config,
+                                    cache_config,
+                                    quant_config=quant_config,
+                                    prefix=f"{prefix}.decoder")
+
+    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
+                encoder_input_ids: torch.Tensor,
+                encoder_positions: torch.Tensor) -> torch.Tensor:
+        r"""
+        Args:
+            input_ids
+                Indices of *decoder* input sequence tokens in the vocabulary.
+ Padding will be ignored by default should you + provide it. + positions + Positions of *decoder* input sequence tokens. + encoder_input_ids + Indices of *encoder* input sequence tokens in the vocabulary. + encoder_positions: + Positions of *encoder* input sequence tokens. + Returns: + Model output torch.Tensor + """ + + encoder_hidden_states = None + + if encoder_input_ids.numel() > 0: + # Run encoder attention if a non-zero number of encoder tokens + # are provided as input + encoder_hidden_states = self.encoder(input_ids=encoder_input_ids, + positions=encoder_positions) + + # decoder outputs consists of + # (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + decoder_input_ids=input_ids, + decoder_positions=positions, + encoder_hidden_states=encoder_hidden_states) + + return decoder_outputs + + +class MBartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant): + base_model_prefix = "model" + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "decoder.": "model.decoder.", + "encoder.": "model.encoder.", + "shared.": "model.shared." + }, + orig_to_new_substr={ + "beta": "bias", + "gamma": "weight", + "LayerNorm": "layernorm", + }, + ) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + lora_config = vllm_config.lora_config + assert config.tie_word_embeddings + self.config = config + self.model = MBartModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + + embed_scale = math.sqrt( + config.d_model) if config.scale_embedding else 1.0 + + self.lm_head = BartParallelLMHead(config.vocab_size, + config.d_model, + embed_scale=embed_scale) + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + *, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + return self.model(input_ids, positions, encoder_input_ids, + encoder_positions) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + model_params_dict = dict(self.named_parameters()) + loaded_params = set() + remaining_weights = [] + shared_embedding_weight = None + + for name, loaded_weight in weights: + if any(skip in name + for skip in ["cls.", "pooler.", "final_logits_bias"]): + continue + if any(embed_name in name for embed_name in [ + 'shared.weight', 'encoder.embed_tokens.weight', + 'decoder.embed_tokens.weight' + ]): + if shared_embedding_weight is None: + shared_embedding_weight = loaded_weight + continue + is_stacked = False + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + vllm_name = name + for src, dst in self.hf_to_vllm_mapper.orig_to_new_substr.items( + ): + vllm_name = vllm_name.replace(src, dst) + for src, dst in self.hf_to_vllm_mapper.orig_to_new_prefix.items( + ): + if 
vllm_name.startswith(src): + vllm_name = dst + vllm_name[len(src):] + break + vllm_name = vllm_name.replace(weight_name, param_name) + if vllm_name in model_params_dict: + param = model_params_dict[vllm_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight, shard_id) + loaded_params.add(vllm_name) + is_stacked = True + break + if not is_stacked: + remaining_weights.append((name, loaded_weight)) + loader = AutoWeightsLoader(self, skip_prefixes=["cls.", "pooler."]) + auto_loaded_params = loader.load_weights(remaining_weights, + mapper=self.hf_to_vllm_mapper) + loaded_params.update(auto_loaded_params) + if shared_embedding_weight is not None: + lm_head_param = self.lm_head.weight + weight_loader = getattr(lm_head_param, "weight_loader", + default_weight_loader) + weight_loader(lm_head_param, shared_embedding_weight) + self.model.encoder.embed_tokens.weight = self.lm_head.weight + self.model.decoder.embed_tokens.weight = self.lm_head.weight + loaded_params.update({ + 'model.encoder.embed_tokens.weight', 'lm_head.weight', + 'model.decoder.embed_tokens.weight' + }) + return loaded_params diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index b817615b4356..109bc1fe5c77 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -141,6 +141,7 @@ _TEXT_GENERATION_MODELS = { # [Encoder-decoder] "BartModel": ("bart", "BartForConditionalGeneration"), "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"), + "MBartForConditionalGeneration": ("bart", "MBartForConditionalGeneration"), } _EMBEDDING_MODELS = {
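
Reviewer note (not part of the diff): the snippet below is a minimal, self-contained sketch of how the new mBART path can be exercised offline. It is distilled from the get_mbart_config() branch of the updated example script in this patch; the checkpoint name, the architectures override, the prompts, and the sampling settings are taken from the patch itself and are illustrative rather than prescriptive.

# Minimal sketch, mirroring the mBART branch of
# examples/offline_inference/encoder_decoder.py as modified by this patch.
from vllm import LLM, SamplingParams
from vllm.inputs import ExplicitEncoderDecoderPrompt

llm = LLM(
    model="facebook/mbart-large-en-ro",
    dtype="float",
    # Some mBART configs do not declare `architectures`, so it is
    # overridden explicitly, as noted in docs/models/supported_models.md.
    hf_overrides={"architectures": ["MBartForConditionalGeneration"]},
)

prompts = [
    # A bare string is treated as the encoder prompt; the decoder
    # prompt is implicitly None.
    "The quick brown fox jumps over the lazy dog.",
    # Encoder and decoder prompts can also be passed explicitly.
    ExplicitEncoderDecoderPrompt(
        encoder_prompt="How are you today?",
        decoder_prompt="",
    ),
]

outputs = llm.generate(prompts, SamplingParams(temperature=0, max_tokens=30))
for output in outputs:
    print(f"Encoder Prompt: {output.encoder_prompt!r}")
    print(f"Decoder Prompt: {output.prompt!r}")
    print(f"Generated Text: {output.outputs[0].text!r}")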