[New Model] mBART model (#22883)

Signed-off-by: 汪志鹏 <wangzhipeng628@gmail.com>
汪志鹏 2025-08-16 20:16:58 +08:00 committed by GitHub
parent 4dff91c93d
commit 829bbd7882
6 changed files with 717 additions and 92 deletions


@@ -330,6 +330,7 @@ th {
| `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | ✅︎ |
| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | |
| `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | |
| `MBartForConditionalGeneration` | mBART | `facebook/mbart-large-en-ro`, `facebook/mbart-large-50`, etc. | | | |
| `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `zai-org/chatglm2-6b`, `zai-org/chatglm3-6b`, `ShieldLM-6B-chatglm3`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `CohereForCausalLM`, `Cohere2ForCausalLM` | Command-R | `CohereLabs/c4ai-command-r-v01`, `CohereLabs/c4ai-command-r7b-12-2024`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `DbrxForCausalLM` | DBRX | `databricks/dbrx-base`, `databricks/dbrx-instruct`, etc. | | ✅︎ | ✅︎ |
@@ -418,6 +419,9 @@ Some models are supported only via the [Transformers backend](#transformers). Th
!!! note
    Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096.

!!! note
    Some mBART models' config files do not define `architectures`. In that case, pass `--hf-overrides '{"architectures": ["MBartForConditionalGeneration"]}'` to explicitly select the `MBartForConditionalGeneration` implementation.
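For offline inference, the same override can be passed programmatically. A minimal sketch, mirroring the bundled encoder-decoder example script (the model ID is illustrative):

```python
from vllm import LLM

# hf_overrides fills in the `architectures` field that is missing from the
# HF config of some mBART checkpoints.
llm = LLM(
    model="facebook/mbart-large-en-ro",
    hf_overrides={"architectures": ["MBartForConditionalGeneration"]},
)
```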
### Pooling Models

See [this page](./pooling_models.md) for more information on how to use pooling models.


@@ -2,9 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrate prompting of text-to-text
encoder/decoder models, specifically BART and mBART.
This script is refactored to allow model selection via command-line arguments.
"""

import argparse
from typing import NamedTuple, Optional

from vllm import LLM, SamplingParams
from vllm.inputs import (
    ExplicitEncoderDecoderPrompt,
@@ -14,119 +19,175 @@ from vllm.inputs import (
)


class ModelRequestData(NamedTuple):
    """
    Holds the configuration for a specific model, including its
    HuggingFace ID and the prompts to use for the demo.
    """

    model_id: str
    encoder_prompts: list
    decoder_prompts: list
    hf_overrides: Optional[dict] = None


def get_bart_config() -> ModelRequestData:
    """
    Returns the configuration for facebook/bart-large-cnn.
    This uses the exact test cases from the original script.
    """
    encoder_prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "An encoder prompt",
    ]
    decoder_prompts = [
        "A decoder prompt",
        "Another decoder prompt",
    ]
    return ModelRequestData(
        model_id="facebook/bart-large-cnn",
        encoder_prompts=encoder_prompts,
        decoder_prompts=decoder_prompts,
    )


def get_mbart_config() -> ModelRequestData:
    """
    Returns the configuration for facebook/mbart-large-en-ro.
    This uses prompts suitable for an English-to-Romanian translation task.
    """
    encoder_prompts = [
        "The quick brown fox jumps over the lazy dog.",
        "How are you today?",
    ]
    decoder_prompts = ["", ""]
    hf_overrides = {"architectures": ["MBartForConditionalGeneration"]}
    return ModelRequestData(
        model_id="facebook/mbart-large-en-ro",
        encoder_prompts=encoder_prompts,
        decoder_prompts=decoder_prompts,
        hf_overrides=hf_overrides,
    )


MODEL_GETTERS = {
    "bart": get_bart_config,
    "mbart": get_mbart_config,
}


def create_all_prompt_types(
    encoder_prompts_raw: list,
    decoder_prompts_raw: list,
    tokenizer,
) -> list:
    """
    Generates a list of diverse prompt types for demonstration.
    This function is generic and uses the provided raw prompts
    to create various vLLM input objects.
    """
    text_prompt_raw = encoder_prompts_raw[0]
    text_prompt = TextPrompt(prompt=encoder_prompts_raw[1 % len(encoder_prompts_raw)])
    tokens_prompt = TokensPrompt(
        prompt_token_ids=tokenizer.encode(
            encoder_prompts_raw[2 % len(encoder_prompts_raw)]
        )
    )
    decoder_tokens_prompt = TokensPrompt(
        prompt_token_ids=tokenizer.encode(decoder_prompts_raw[0])
    )

    # A single prompt is treated as the encoder input;
    # the decoder input prompt is assumed to be None.
    single_prompt_examples = [
        text_prompt_raw,
        text_prompt,
        tokens_prompt,
    ]

    # Encoder and decoder prompts can independently be text or tokens;
    # the combinations below are illustrative, not exhaustive.
    explicit_pair_examples = [
        ExplicitEncoderDecoderPrompt(
            encoder_prompt=text_prompt_raw,
            decoder_prompt=decoder_tokens_prompt,
        ),
        ExplicitEncoderDecoderPrompt(
            encoder_prompt=text_prompt,
            decoder_prompt=decoder_prompts_raw[1 % len(decoder_prompts_raw)],
        ),
        ExplicitEncoderDecoderPrompt(
            encoder_prompt=tokens_prompt,
            decoder_prompt=text_prompt,
        ),
    ]

    # zip_enc_dec_prompts pairs encoder and decoder prompts into
    # ExplicitEncoderDecoderPrompt instances.
    zipped_prompt_list = zip_enc_dec_prompts(
        encoder_prompts_raw,
        decoder_prompts_raw,
    )

    return single_prompt_examples + explicit_pair_examples + zipped_prompt_list


def create_sampling_params() -> SamplingParams:
    """Create a sampling params object."""
    return SamplingParams(
        temperature=0,
        top_p=1.0,
        min_tokens=0,
        max_tokens=30,
    )


def print_outputs(outputs: list):
    """Formats and prints the generation outputs."""
    print("-" * 80)
    for i, output in enumerate(outputs):
        prompt = output.prompt
        encoder_prompt = output.encoder_prompt
        generated_text = output.outputs[0].text
        print(f"Output {i + 1}:")
        print(f"Encoder Prompt: {encoder_prompt!r}")
        print(f"Decoder Prompt: {prompt!r}")
        print(f"Generated Text: {generated_text!r}")
        print("-" * 80)


def main(args):
    """Main execution function."""
    model_key = args.model
    if model_key not in MODEL_GETTERS:
        raise ValueError(
            f"Unknown model: {model_key}. "
            f"Available models: {list(MODEL_GETTERS.keys())}"
        )
    config_getter = MODEL_GETTERS[model_key]
    model_config = config_getter()

    print(f"🚀 Running demo for model: {model_config.model_id}")

    llm = LLM(
        model=model_config.model_id,
        dtype="float",
        hf_overrides=model_config.hf_overrides,
    )

    tokenizer = llm.llm_engine.get_tokenizer_group()
    prompts = create_all_prompt_types(
        encoder_prompts_raw=model_config.encoder_prompts,
        decoder_prompts_raw=model_config.decoder_prompts,
        tokenizer=tokenizer,
    )
    sampling_params = create_sampling_params()

    # Generate output tokens from the prompts. The output is a list of
    # RequestOutput objects that contain the prompt, generated text,
    # and other information.
    outputs = llm.generate(prompts, sampling_params)
    print_outputs(outputs)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="A flexible demo for vLLM encoder-decoder models."
    )
    parser.add_argument(
        "--model",
        "-m",
        type=str,
        default="bart",
        choices=MODEL_GETTERS.keys(),
        help="The short name of the model to run.",
    )
    args = parser.parse_args()
    main(args)
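For quick reference, a trimmed-down sketch of the mBART path exercised by the script above, without the demo scaffolding (the prompt text and sampling settings are illustrative):

```python
from vllm import LLM, SamplingParams
from vllm.inputs import ExplicitEncoderDecoderPrompt

llm = LLM(
    model="facebook/mbart-large-en-ro",
    dtype="float",
    hf_overrides={"architectures": ["MBartForConditionalGeneration"]},
)

# English goes to the encoder; an empty decoder prompt lets the model start
# generation from its own decoder start token.
prompt = ExplicitEncoderDecoderPrompt(
    encoder_prompt="The quick brown fox jumps over the lazy dog.",
    decoder_prompt="",
)

outputs = llm.generate([prompt], SamplingParams(temperature=0, max_tokens=30))
print(outputs[0].outputs[0].text)
```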


@@ -0,0 +1,123 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import pytest
from transformers import AutoModelForSeq2SeqLM
from vllm.sequence import SampleLogprobs
from ....conftest import DecoderPromptType, HfRunner, VllmRunner
from ...utils import check_logprobs_close
def vllm_to_hf_output(
vllm_output: tuple[list[int], str, Optional[SampleLogprobs]],
decoder_prompt_type: DecoderPromptType,
):
"""Sanitize vllm output to be comparable with hf output."""
output_ids, output_str, out_logprobs = vllm_output
hf_output_str = output_str + "</s>"
return output_ids, hf_output_str, out_logprobs
def run_test(
hf_runner: type[HfRunner],
vllm_runner: type[VllmRunner],
prompts: list[dict[str, str]],
decoder_prompt_type: DecoderPromptType,
model: str,
*,
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
) -> None:
'''
Test the vLLM mBART model by validating it against HuggingFace (HF).
(Docstring content is omitted for brevity)
'''
vllm_prompts = prompts
if decoder_prompt_type == DecoderPromptType.NONE:
vllm_prompts = [{
"encoder_prompt": p['encoder_prompt'],
"decoder_prompt": ""
} for p in prompts]
vllm_kwargs = {
"hf_overrides": {
"architectures": ["MBartForConditionalGeneration"]
}
}
with vllm_runner(model,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True,
**vllm_kwargs) as vllm_model: # type: ignore
vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
vllm_prompts, max_tokens, num_logprobs)
hf_kwargs = {
"top_k": None,
"num_beams": 1,
"repetition_penalty": 1.0,
"top_p": 1.0,
"length_penalty": 1.0,
"early_stopping": False,
"no_repeat_ngram_size": None,
"min_length": 0
}
with hf_runner(model, dtype=dtype,
auto_cls=AutoModelForSeq2SeqLM) as hf_model:
hf_kwargs["decoder_start_token_id"] = (
hf_model.tokenizer.lang_code_to_id["ro_RO"])
hf_outputs = (
hf_model.generate_encoder_decoder_greedy_logprobs_limit(
prompts, # HF runner still uses the original prompts
max_tokens,
num_logprobs,
**hf_kwargs,
))
hf_skip_tokens = 0
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=[
vllm_to_hf_output(vllm_output, decoder_prompt_type)
for vllm_output in vllm_outputs
],
name_0="hf",
name_1="vllm",
num_outputs_0_skip_tokens=hf_skip_tokens,
)
@pytest.mark.parametrize(
"model",
[pytest.param("facebook/mbart-large-en-ro")],
)
@pytest.mark.parametrize("dtype", ["float", "bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model,
dtype, max_tokens, num_logprobs, decoder_prompt_type) -> None:
run_test(
hf_runner,
vllm_runner,
example_encoder_decoder_prompts[decoder_prompt_type],
decoder_prompt_type,
model,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)


@@ -316,6 +316,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
    # [Encoder-decoder]
    "BartModel": _HfExamplesInfo("facebook/bart-base"),
    "BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"),
    "MBartForConditionalGeneration": _HfExamplesInfo("facebook/mbart-large-en-ro",  # noqa: E501
        hf_overrides={"architectures": ["MBartForConditionalGeneration"]}),  # noqa: E501
}

_EMBEDDING_EXAMPLE_MODELS = {


@@ -46,7 +46,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsQuant, SupportsV0Only
from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
                    maybe_prefix)

logger = logging.get_logger(__name__)

@@ -422,10 +423,7 @@ class BartEncoderLayer(nn.Module):
        if hidden_states.dtype == torch.float16 and (
                torch.isinf(hidden_states).any()
                or torch.isnan(hidden_states).any()):
            hidden_states = cast_overflow_tensors(hidden_states)

        return hidden_states

@@ -906,3 +904,439 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
        })
        return loaded_params
class MBartEncoderLayer(BartEncoderLayer):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
r"""
Args:
hidden_states
torch.Tensor of *encoder* input embeddings.
Returns:
Encoder layer output torch.Tensor
"""
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
fc1_out, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(fc1_out)
hidden_states, _ = self.fc2(hidden_states)
hidden_states = residual + hidden_states
if hidden_states.dtype == torch.float16 and (
torch.isinf(hidden_states).any()
or torch.isnan(hidden_states).any()):
hidden_states = cast_overflow_tensors(hidden_states)
return hidden_states
class MBartDecoderLayer(BartDecoderLayer):
def forward(
self,
decoder_hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
) -> torch.Tensor:
residual = decoder_hidden_states
hidden_states = self.self_attn_layer_norm(decoder_hidden_states)
# Self Attention
hidden_states = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states
# Cross-Attention Block
residual = hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
hidden_states = self.encoder_attn(
decoder_hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
fc1_out, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(fc1_out)
hidden_states, _ = self.fc2(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class MBartEncoder(nn.Module):
"""
Transformer encoder consisting of *config.encoder_layers*
    self attention layers. Each layer is a [`MBartEncoderLayer`].
Args:
config: BartConfig
embed_tokens (nn.Embedding): output embedding
"""
def __init__(self,
config: BartConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
embed_tokens: Optional[nn.Embedding] = None,
prefix: str = ""):
super().__init__()
self.cache_config = cache_config
self.quant_config = quant_config
self.lora_config = lora_config
embed_dim = config.d_model
self.max_source_positions = config.max_position_embeddings
embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
embed_dim,
embed_scale=embed_scale)
if embed_tokens is not None:
self.embed_tokens.weight = embed_tokens.weight
self.embed_positions = BartLearnedPositionalEmbedding(
config.max_position_embeddings,
embed_dim,
)
self.layers = nn.ModuleList([
MBartEncoderLayer(config,
cache_config,
quant_config,
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(config.encoder_layers)
])
self.layernorm_embedding = nn.LayerNorm(embed_dim)
        self.layer_norm = nn.LayerNorm(config.d_model)  # mBART adds a final encoder LayerNorm (not present in BART)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
r"""
Args:
input_ids
Indices of *encoder* input sequence tokens in the vocabulary.
Padding will be ignored by default should you
provide it.
positions
Positions of *encoder* input sequence tokens.
Returns:
            Encoder output torch.Tensor
"""
# retrieve input_ids and inputs_embeds
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
embed_pos = self.embed_positions(positions)
embed_pos = embed_pos.to(inputs_embeds.device)
hidden_states = inputs_embeds + embed_pos
hidden_states = self.layernorm_embedding(hidden_states)
for encoder_layer in self.layers:
hidden_states = encoder_layer(hidden_states=hidden_states)
hidden_states = self.layer_norm(hidden_states)
return hidden_states
class MBartDecoder(nn.Module):
"""
Transformer decoder consisting of *config.decoder_layers* layers.
    Each layer is a [`MBartDecoderLayer`]
Args:
config: BartConfig
embed_tokens (nn.Embedding): output embedding
"""
def __init__(
self,
config: BartConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
embed_tokens: Optional[nn.Embedding] = None,
prefix: str = "",
):
super().__init__()
self.cache_config = cache_config
self.quant_config = quant_config
self.lora_config = lora_config
self.max_target_positions = config.max_position_embeddings
embed_scale = math.sqrt(
config.d_model) if config.scale_embedding else 1.0
self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
config.d_model,
embed_scale=embed_scale)
if embed_tokens is not None:
self.embed_tokens.weight = embed_tokens.weight
self.embed_positions = BartLearnedPositionalEmbedding(
config.max_position_embeddings,
config.d_model,
)
self.layers = nn.ModuleList(
[MBartDecoderLayer(config, cache_config, quant_config,
prefix=f"{prefix}.layers.{layer_idx}") \
for layer_idx in range(config.decoder_layers)])
self.layernorm_embedding = nn.LayerNorm(config.d_model)
self.layer_norm = nn.LayerNorm(config.d_model)
def forward(
self,
decoder_input_ids: torch.Tensor,
decoder_positions: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor],
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
r"""
Args:
decoder_input_ids
Indices of *decoder* input sequence tokens in the vocabulary.
Padding will be ignored by default should you
provide it.
decoder_positions
Positions of *decoder* input sequence tokens.
encoder_hidden_states:
Tensor of encoder output embeddings
Returns:
Decoder output torch.Tensor
"""
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(decoder_input_ids)
else:
decoder_positions = inputs_embeds[:, -1]
# embed positions
embed_pos = self.embed_positions(decoder_positions)
embed_pos = embed_pos.to(inputs_embeds.device)
hidden_states = inputs_embeds + embed_pos
hidden_states = self.layernorm_embedding(hidden_states)
# decoder layers
for decoder_layer in self.layers:
hidden_states = decoder_layer(
decoder_hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
)
hidden_states = self.layer_norm(hidden_states)
return hidden_states
class MBartModel(nn.Module, SupportsQuant):
_tied_weights_keys = [
"encoder.embed_tokens.weight", "decoder.embed_tokens.weight"
]
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.encoder = MBartEncoder(config,
cache_config,
quant_config=quant_config,
prefix=f"{prefix}.encoder")
self.decoder = MBartDecoder(config,
cache_config,
quant_config=quant_config,
prefix=f"{prefix}.decoder")
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
encoder_input_ids: torch.Tensor,
encoder_positions: torch.Tensor) -> torch.Tensor:
r"""
Args:
input_ids
Indices of *decoder* input sequence tokens in the vocabulary.
Padding will be ignored by default should you
provide it.
positions
Positions of *decoder* input sequence tokens.
encoder_input_ids
Indices of *encoder* input sequence tokens in the vocabulary.
encoder_positions:
Positions of *encoder* input sequence tokens.
Returns:
Model output torch.Tensor
"""
encoder_hidden_states = None
if encoder_input_ids.numel() > 0:
# Run encoder attention if a non-zero number of encoder tokens
# are provided as input
encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
positions=encoder_positions)
        # decoder outputs consist of
# (dec_features, past_key_value, dec_hidden, dec_attn)
decoder_outputs = self.decoder(
decoder_input_ids=input_ids,
decoder_positions=positions,
encoder_hidden_states=encoder_hidden_states)
return decoder_outputs
class MBartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
base_model_prefix = "model"
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"decoder.": "model.decoder.",
"encoder.": "model.encoder.",
"shared.": "model.shared."
},
orig_to_new_substr={
"beta": "bias",
"gamma": "weight",
"LayerNorm": "layernorm",
},
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
lora_config = vllm_config.lora_config
assert config.tie_word_embeddings
self.config = config
self.model = MBartModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
embed_scale = math.sqrt(
config.d_model) if config.scale_embedding else 1.0
self.lm_head = BartParallelLMHead(config.vocab_size,
config.d_model,
embed_scale=embed_scale)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
*,
encoder_input_ids: torch.Tensor,
encoder_positions: torch.Tensor,
**kwargs,
) -> torch.Tensor:
return self.model(input_ids, positions, encoder_input_ids,
encoder_positions)
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
model_params_dict = dict(self.named_parameters())
loaded_params = set()
remaining_weights = []
shared_embedding_weight = None
for name, loaded_weight in weights:
if any(skip in name
for skip in ["cls.", "pooler.", "final_logits_bias"]):
continue
if any(embed_name in name for embed_name in [
'shared.weight', 'encoder.embed_tokens.weight',
'decoder.embed_tokens.weight'
]):
if shared_embedding_weight is None:
shared_embedding_weight = loaded_weight
continue
is_stacked = False
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
vllm_name = name
for src, dst in self.hf_to_vllm_mapper.orig_to_new_substr.items(
):
vllm_name = vllm_name.replace(src, dst)
for src, dst in self.hf_to_vllm_mapper.orig_to_new_prefix.items(
):
if vllm_name.startswith(src):
vllm_name = dst + vllm_name[len(src):]
break
vllm_name = vllm_name.replace(weight_name, param_name)
if vllm_name in model_params_dict:
param = model_params_dict[vllm_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight, shard_id)
loaded_params.add(vllm_name)
is_stacked = True
break
if not is_stacked:
remaining_weights.append((name, loaded_weight))
loader = AutoWeightsLoader(self, skip_prefixes=["cls.", "pooler."])
auto_loaded_params = loader.load_weights(remaining_weights,
mapper=self.hf_to_vllm_mapper)
loaded_params.update(auto_loaded_params)
if shared_embedding_weight is not None:
lm_head_param = self.lm_head.weight
weight_loader = getattr(lm_head_param, "weight_loader",
default_weight_loader)
weight_loader(lm_head_param, shared_embedding_weight)
self.model.encoder.embed_tokens.weight = self.lm_head.weight
self.model.decoder.embed_tokens.weight = self.lm_head.weight
loaded_params.update({
'model.encoder.embed_tokens.weight', 'lm_head.weight',
'model.decoder.embed_tokens.weight'
})
return loaded_params
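As a side note on the weight loading above: below is a standalone sketch of the checkpoint-name mapping it performs, combining `hf_to_vllm_mapper` with `stacked_params_mapping`. The example weight names are hypothetical and only illustrate the renaming, not the actual tensor loading.

```python
# Standalone sketch of the name mapping performed in load_weights; the example
# checkpoint names below are hypothetical and only illustrate the renaming.
ORIG_TO_NEW_PREFIX = {
    "decoder.": "model.decoder.",
    "encoder.": "model.encoder.",
    "shared.": "model.shared.",
}
ORIG_TO_NEW_SUBSTR = {"beta": "bias", "gamma": "weight", "LayerNorm": "layernorm"}
STACKED_PARAMS = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]


def map_checkpoint_name(name: str):
    """Return (vllm_param_name, shard_id); shard_id is None for unfused params."""
    for src, dst in ORIG_TO_NEW_SUBSTR.items():
        name = name.replace(src, dst)
    for src, dst in ORIG_TO_NEW_PREFIX.items():
        if name.startswith(src):
            name = dst + name[len(src):]
            break
    for fused_name, shard_name, shard_id in STACKED_PARAMS:
        if shard_name in name:
            # q/k/v projections are fused into a single qkv_proj parameter and
            # loaded shard by shard.
            return name.replace(shard_name, fused_name), shard_id
    return name, None


print(map_checkpoint_name("encoder.layers.0.self_attn.q_proj.weight"))
# -> ('model.encoder.layers.0.self_attn.qkv_proj.weight', 'q')
print(map_checkpoint_name("decoder.layernorm_embedding.gamma"))
# -> ('model.decoder.layernorm_embedding.weight', None)
```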


@@ -141,6 +141,7 @@ _TEXT_GENERATION_MODELS = {
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
    "MBartForConditionalGeneration": ("bart", "MBartForConditionalGeneration"),
}

_EMBEDDING_MODELS = {