[New Model] mBART model (#22883)
Signed-off-by: 汪志鹏 <wangzhipeng628@gmail.com>
This commit is contained in:
parent 4dff91c93d
commit 829bbd7882
@@ -330,6 +330,7 @@ th {
| `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | ✅︎ |
| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | |
| `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | |
| `MBartForConditionalGeneration` | mBART | `facebook/mbart-large-en-ro`, `facebook/mbart-large-50`, etc. | | | |
| `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `zai-org/chatglm2-6b`, `zai-org/chatglm3-6b`, `ShieldLM-6B-chatglm3`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `CohereForCausalLM`, `Cohere2ForCausalLM` | Command-R | `CohereLabs/c4ai-command-r-v01`, `CohereLabs/c4ai-command-r7b-12-2024`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `DbrxForCausalLM` | DBRX | `databricks/dbrx-base`, `databricks/dbrx-instruct`, etc. | | ✅︎ | ✅︎ |
@@ -418,6 +419,9 @@ Some models are supported only via the [Transformers backend](#transformers). Th
!!! note
    Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096.

!!! note
    Some mBART models' config files do not have an `architecture` defined. Therefore, you need to use `--hf-overrides '{"architectures": ["MBartForConditionalGeneration"]}'` to explicitly specify the use of the `MBartForConditionalGeneration` architecture.
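
For example, a minimal offline-inference sketch of that override (the checkpoint name is illustrative; use whichever mBART model you are loading):

```python
from vllm import LLM

# Assumption: this mBART checkpoint ships without an "architectures" entry in
# its config, so the architecture is specified explicitly via hf_overrides.
llm = LLM(
    model="facebook/mbart-large-en-ro",
    hf_overrides={"architectures": ["MBartForConditionalGeneration"]},
)
```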

### Pooling Models

See [this page](./pooling_models.md) for more information on how to use pooling models.
@@ -2,9 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrate prompting of text-to-text
encoder/decoder models, specifically BART
encoder/decoder models, specifically BART and mBART.

This script is refactored to allow model selection via command-line arguments.
"""

import argparse
from typing import NamedTuple, Optional

from vllm import LLM, SamplingParams
from vllm.inputs import (
    ExplicitEncoderDecoderPrompt,
@@ -14,119 +19,175 @@ from vllm.inputs import (
)


def create_prompts(tokenizer):
    # Test prompts
    #
    # This section shows all of the valid ways to prompt an
    # encoder/decoder model.
    #
    # - Helpers for building prompts
    text_prompt_raw = "Hello, my name is"
    text_prompt = TextPrompt(prompt="The president of the United States is")
class ModelRequestData(NamedTuple):
    """
    Holds the configuration for a specific model, including its
    HuggingFace ID and the prompts to use for the demo.
    """

    model_id: str
    encoder_prompts: list
    decoder_prompts: list
    hf_overrides: Optional[dict] = None


def get_bart_config() -> ModelRequestData:
    """
    Returns the configuration for facebook/bart-large-cnn.
    This uses the exact test cases from the original script.
    """
    encoder_prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "An encoder prompt",
    ]
    decoder_prompts = [
        "A decoder prompt",
        "Another decoder prompt",
    ]
    return ModelRequestData(
        model_id="facebook/bart-large-cnn",
        encoder_prompts=encoder_prompts,
        decoder_prompts=decoder_prompts,
    )


def get_mbart_config() -> ModelRequestData:
    """
    Returns the configuration for facebook/mbart-large-en-ro.
    This uses prompts suitable for an English-to-Romanian translation task.
    """
    encoder_prompts = [
        "The quick brown fox jumps over the lazy dog.",
        "How are you today?",
    ]
    decoder_prompts = ["", ""]
    hf_overrides = {"architectures": ["MBartForConditionalGeneration"]}
    return ModelRequestData(
        model_id="facebook/mbart-large-en-ro",
        encoder_prompts=encoder_prompts,
        decoder_prompts=decoder_prompts,
        hf_overrides=hf_overrides,
    )


MODEL_GETTERS = {
    "bart": get_bart_config,
    "mbart": get_mbart_config,
}

def create_all_prompt_types(
    encoder_prompts_raw: list,
    decoder_prompts_raw: list,
    tokenizer,
) -> list:
    """
    Generates a list of diverse prompt types for demonstration.
    This function is generic and uses the provided raw prompts
    to create various vLLM input objects.
    """
    text_prompt_raw = encoder_prompts_raw[0]
    text_prompt = TextPrompt(prompt=encoder_prompts_raw[1 % len(encoder_prompts_raw)])
    tokens_prompt = TokensPrompt(
        prompt_token_ids=tokenizer.encode(prompt="The capital of France is")
        prompt_token_ids=tokenizer.encode(
            encoder_prompts_raw[2 % len(encoder_prompts_raw)]
        )
    # - Pass a single prompt to encoder/decoder model
    # (implicitly encoder input prompt);
    # decoder input prompt is assumed to be None

    single_text_prompt_raw = text_prompt_raw  # Pass a string directly
    single_text_prompt = text_prompt  # Pass a TextPrompt
    single_tokens_prompt = tokens_prompt  # Pass a TokensPrompt

    # ruff: noqa: E501
    # - Pass explicit encoder and decoder input prompts within one data structure.
    # Encoder and decoder prompts can both independently be text or tokens, with
    # no requirement that they be the same prompt type. Some example prompt-type
    # combinations are shown below, note that these are not exhaustive.

    enc_dec_prompt1 = ExplicitEncoderDecoderPrompt(
        # Pass encoder prompt string directly, &
        # pass decoder prompt tokens
        encoder_prompt=single_text_prompt_raw,
        decoder_prompt=single_tokens_prompt,
    )
    enc_dec_prompt2 = ExplicitEncoderDecoderPrompt(
        # Pass TextPrompt to encoder, and
        # pass decoder prompt string directly
        encoder_prompt=single_text_prompt,
        decoder_prompt=single_text_prompt_raw,
    )
    enc_dec_prompt3 = ExplicitEncoderDecoderPrompt(
        # Pass encoder prompt tokens directly, and
        # pass TextPrompt to decoder
        encoder_prompt=single_tokens_prompt,
        decoder_prompt=single_text_prompt,
    )

    # - Finally, here's a useful helper function for zipping encoder and
    # decoder prompts together into a list of ExplicitEncoderDecoderPrompt
    # instances
    decoder_tokens_prompt = TokensPrompt(
        prompt_token_ids=tokenizer.encode(decoder_prompts_raw[0])
    )
    single_prompt_examples = [
        text_prompt_raw,
        text_prompt,
        tokens_prompt,
    ]
    explicit_pair_examples = [
        ExplicitEncoderDecoderPrompt(
            encoder_prompt=text_prompt_raw,
            decoder_prompt=decoder_tokens_prompt,
        ),
        ExplicitEncoderDecoderPrompt(
            encoder_prompt=text_prompt,
            decoder_prompt=decoder_prompts_raw[1 % len(decoder_prompts_raw)],
        ),
        ExplicitEncoderDecoderPrompt(
            encoder_prompt=tokens_prompt,
            decoder_prompt=text_prompt,
        ),
    ]
    zipped_prompt_list = zip_enc_dec_prompts(
        ["An encoder prompt", "Another encoder prompt"],
        ["A decoder prompt", "Another decoder prompt"],
        encoder_prompts_raw,
        decoder_prompts_raw,
    )

    # - Let's put all of the above example prompts together into one list
    # which we will pass to the encoder/decoder LLM.
    return [
        single_text_prompt_raw,
        single_text_prompt,
        single_tokens_prompt,
        enc_dec_prompt1,
        enc_dec_prompt2,
        enc_dec_prompt3,
    ] + zipped_prompt_list
    return single_prompt_examples + explicit_pair_examples + zipped_prompt_list

# Create a sampling params object.
def create_sampling_params():
def create_sampling_params() -> SamplingParams:
    """Create a sampling params object."""
    return SamplingParams(
        temperature=0,
        top_p=1.0,
        min_tokens=0,
        max_tokens=20,
        max_tokens=30,
    )


# Print the outputs.
def print_outputs(outputs):
    print("-" * 50)
def print_outputs(outputs: list):
    """Formats and prints the generation outputs."""
    print("-" * 80)
    for i, output in enumerate(outputs):
        prompt = output.prompt
        encoder_prompt = output.encoder_prompt
        generated_text = output.outputs[0].text
        print(f"Output {i + 1}:")
        print(
            f"Encoder prompt: {encoder_prompt!r}\n"
            f"Decoder prompt: {prompt!r}\n"
            f"Generated text: {generated_text!r}"
        print(f"Encoder Prompt: {encoder_prompt!r}")
        print(f"Decoder Prompt: {prompt!r}")
        print(f"Generated Text: {generated_text!r}")
        print("-" * 80)

def main(args):
    """Main execution function."""
    model_key = args.model
    if model_key not in MODEL_GETTERS:
        raise ValueError(
            f"Unknown model: {model_key}. "
            f"Available models: {list(MODEL_GETTERS.keys())}"
        )
    print("-" * 50)
    config_getter = MODEL_GETTERS[model_key]
    model_config = config_getter()


def main():
    dtype = "float"

    # Create a BART encoder/decoder model instance
    print(f"🚀 Running demo for model: {model_config.model_id}")
    llm = LLM(
        model="facebook/bart-large-cnn",
        dtype=dtype,
        model=model_config.model_id,
        dtype="float",
        hf_overrides=model_config.hf_overrides,
    )

    # Get BART tokenizer
    tokenizer = llm.llm_engine.get_tokenizer_group()

    prompts = create_prompts(tokenizer)
    prompts = create_all_prompt_types(
        encoder_prompts_raw=model_config.encoder_prompts,
        decoder_prompts_raw=model_config.decoder_prompts,
        tokenizer=tokenizer,
    )
    sampling_params = create_sampling_params()

    # Generate output tokens from the prompts. The output is a list of
    # RequestOutput objects that contain the prompt, generated
    # text, and other information.
    outputs = llm.generate(prompts, sampling_params)

    print_outputs(outputs)


if __name__ == "__main__":
    main()
    parser = argparse.ArgumentParser(
        description="A flexible demo for vLLM encoder-decoder models."
    )
    parser.add_argument(
        "--model",
        "-m",
        type=str,
        default="bart",
        choices=MODEL_GETTERS.keys(),
        help="The short name of the model to run.",
    )
    args = parser.parse_args()
    main(args)
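
# Example invocations (a sketch; the script path below assumes the usual
# examples/ layout of the vLLM repo and may differ in your checkout):
#
#   python examples/offline_inference/encoder_decoder.py --model bart
#   python examples/offline_inference/encoder_decoder.py --model mbart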
tests/models/language/generation/test_mbart.py (new file, 123 lines)
@@ -0,0 +1,123 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional

import pytest
from transformers import AutoModelForSeq2SeqLM

from vllm.sequence import SampleLogprobs

from ....conftest import DecoderPromptType, HfRunner, VllmRunner
from ...utils import check_logprobs_close


def vllm_to_hf_output(
    vllm_output: tuple[list[int], str, Optional[SampleLogprobs]],
    decoder_prompt_type: DecoderPromptType,
):
    """Sanitize vllm output to be comparable with hf output."""
    output_ids, output_str, out_logprobs = vllm_output
    hf_output_str = output_str + "</s>"
    return output_ids, hf_output_str, out_logprobs


def run_test(
    hf_runner: type[HfRunner],
    vllm_runner: type[VllmRunner],
    prompts: list[dict[str, str]],
    decoder_prompt_type: DecoderPromptType,
    model: str,
    *,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
) -> None:
    '''
    Test the vLLM mBART model by validating it against HuggingFace (HF).
    (Docstring content is omitted for brevity)
    '''

    vllm_prompts = prompts
    if decoder_prompt_type == DecoderPromptType.NONE:
        vllm_prompts = [{
            "encoder_prompt": p['encoder_prompt'],
            "decoder_prompt": ""
        } for p in prompts]

    vllm_kwargs = {
        "hf_overrides": {
            "architectures": ["MBartForConditionalGeneration"]
        }
    }

    with vllm_runner(model,
                     dtype=dtype,
                     tensor_parallel_size=tensor_parallel_size,
                     distributed_executor_backend=distributed_executor_backend,
                     enforce_eager=True,
                     **vllm_kwargs) as vllm_model:  # type: ignore
        vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
            vllm_prompts, max_tokens, num_logprobs)

    hf_kwargs = {
        "top_k": None,
        "num_beams": 1,
        "repetition_penalty": 1.0,
        "top_p": 1.0,
        "length_penalty": 1.0,
        "early_stopping": False,
        "no_repeat_ngram_size": None,
        "min_length": 0
    }

    with hf_runner(model, dtype=dtype,
                   auto_cls=AutoModelForSeq2SeqLM) as hf_model:
        hf_kwargs["decoder_start_token_id"] = (
            hf_model.tokenizer.lang_code_to_id["ro_RO"])

        hf_outputs = (
            hf_model.generate_encoder_decoder_greedy_logprobs_limit(
                prompts,  # HF runner still uses the original prompts
                max_tokens,
                num_logprobs,
                **hf_kwargs,
            ))

    hf_skip_tokens = 0

    check_logprobs_close(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=[
            vllm_to_hf_output(vllm_output, decoder_prompt_type)
            for vllm_output in vllm_outputs
        ],
        name_0="hf",
        name_1="vllm",
        num_outputs_0_skip_tokens=hf_skip_tokens,
    )


@pytest.mark.parametrize(
    "model",
    [pytest.param("facebook/mbart-large-en-ro")],
)
@pytest.mark.parametrize("dtype", ["float", "bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model,
                dtype, max_tokens, num_logprobs, decoder_prompt_type) -> None:

    run_test(
        hf_runner,
        vllm_runner,
        example_encoder_decoder_prompts[decoder_prompt_type],
        decoder_prompt_type,
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )
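
# A quick way to run just this file locally (a sketch; assumes a working
# vLLM development install with the test dependencies available):
#
#   pytest tests/models/language/generation/test_mbart.py -v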

@@ -316,6 +316,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
    # [Encoder-decoder]
    "BartModel": _HfExamplesInfo("facebook/bart-base"),
    "BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"),
    "MBartForConditionalGeneration": _HfExamplesInfo("facebook/mbart-large-en-ro",  # noqa: E501
        hf_overrides={"architectures": ["MBartForConditionalGeneration"]}),  # noqa: E501
}

_EMBEDDING_EXAMPLE_MODELS = {
@@ -46,7 +46,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsQuant, SupportsV0Only
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
                    maybe_prefix)

logger = logging.get_logger(__name__)

@@ -422,10 +423,7 @@ class BartEncoderLayer(nn.Module):
        if hidden_states.dtype == torch.float16 and (
                torch.isinf(hidden_states).any()
                or torch.isnan(hidden_states).any()):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states,
                                        min=-clamp_value,
                                        max=clamp_value)
            hidden_states = cast_overflow_tensors(hidden_states)

        return hidden_states

@@ -906,3 +904,439 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
        })

        return loaded_params


class MBartEncoderLayer(BartEncoderLayer):

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        r"""
        Args:
            hidden_states
                torch.Tensor of *encoder* input embeddings.
        Returns:
            Encoder layer output torch.Tensor
        """
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(hidden_states=hidden_states)

        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        fc1_out, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(fc1_out)

        hidden_states, _ = self.fc2(hidden_states)

        hidden_states = residual + hidden_states

        if hidden_states.dtype == torch.float16 and (
                torch.isinf(hidden_states).any()
                or torch.isnan(hidden_states).any()):
            hidden_states = cast_overflow_tensors(hidden_states)

        return hidden_states


class MBartDecoderLayer(BartDecoderLayer):

    def forward(
        self,
        decoder_hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        residual = decoder_hidden_states
        hidden_states = self.self_attn_layer_norm(decoder_hidden_states)

        # Self Attention
        hidden_states = self.self_attn(hidden_states=hidden_states)

        hidden_states = residual + hidden_states

        # Cross-Attention Block

        residual = hidden_states
        hidden_states = self.encoder_attn_layer_norm(hidden_states)

        hidden_states = self.encoder_attn(
            decoder_hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
        )

        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        fc1_out, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(fc1_out)

        hidden_states, _ = self.fc2(hidden_states)

        hidden_states = residual + hidden_states

        return hidden_states


class MBartEncoder(nn.Module):
    """
    Transformer encoder consisting of *config.encoder_layers*
    self attention layers. Each layer is a [`BartEncoderLayer`].
    Args:
        config: BartConfig
        embed_tokens (nn.Embedding): output embedding
    """

    def __init__(self,
                 config: BartConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 lora_config: Optional[LoRAConfig] = None,
                 embed_tokens: Optional[nn.Embedding] = None,
                 prefix: str = ""):
        super().__init__()

        self.cache_config = cache_config
        self.quant_config = quant_config
        self.lora_config = lora_config
        embed_dim = config.d_model
        self.max_source_positions = config.max_position_embeddings
        embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
                                                    embed_dim,
                                                    embed_scale=embed_scale)

        if embed_tokens is not None:
            self.embed_tokens.weight = embed_tokens.weight

        self.embed_positions = BartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            embed_dim,
        )
        self.layers = nn.ModuleList([
            MBartEncoderLayer(config,
                              cache_config,
                              quant_config,
                              prefix=f"{prefix}.layers.{layer_idx}")
            for layer_idx in range(config.encoder_layers)
        ])

        self.layernorm_embedding = nn.LayerNorm(embed_dim)
        # mBART applies a final LayerNorm after the encoder stack (unlike BART).
        self.layer_norm = nn.LayerNorm(config.d_model)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            input_ids
                Indices of *encoder* input sequence tokens in the vocabulary.
                Padding will be ignored by default should you
                provide it.
            positions
                Positions of *encoder* input sequence tokens.
        Returns:
            Encoder output torch.Tensor
        """
        # retrieve input_ids and inputs_embeds
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        embed_pos = self.embed_positions(positions)
        embed_pos = embed_pos.to(inputs_embeds.device)

        hidden_states = inputs_embeds + embed_pos
        hidden_states = self.layernorm_embedding(hidden_states)

        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states=hidden_states)

        hidden_states = self.layer_norm(hidden_states)
        return hidden_states


class MBartDecoder(nn.Module):
    """
    Transformer decoder consisting of *config.decoder_layers* layers.
    Each layer is a [`BartDecoderLayer`]
    Args:
        config: BartConfig
        embed_tokens (nn.Embedding): output embedding
    """

    def __init__(
        self,
        config: BartConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        embed_tokens: Optional[nn.Embedding] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.cache_config = cache_config
        self.quant_config = quant_config
        self.lora_config = lora_config
        self.max_target_positions = config.max_position_embeddings
        embed_scale = math.sqrt(
            config.d_model) if config.scale_embedding else 1.0

        self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
                                                    config.d_model,
                                                    embed_scale=embed_scale)

        if embed_tokens is not None:
            self.embed_tokens.weight = embed_tokens.weight

        self.embed_positions = BartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
        )

        self.layers = nn.ModuleList(
            [MBartDecoderLayer(config, cache_config, quant_config,
                               prefix=f"{prefix}.layers.{layer_idx}")
             for layer_idx in range(config.decoder_layers)])

        self.layernorm_embedding = nn.LayerNorm(config.d_model)
        self.layer_norm = nn.LayerNorm(config.d_model)

    def forward(
        self,
        decoder_input_ids: torch.Tensor,
        decoder_positions: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            decoder_input_ids
                Indices of *decoder* input sequence tokens in the vocabulary.
                Padding will be ignored by default should you
                provide it.
            decoder_positions
                Positions of *decoder* input sequence tokens.
            encoder_hidden_states:
                Tensor of encoder output embeddings
        Returns:
            Decoder output torch.Tensor
        """
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(decoder_input_ids)
        else:
            decoder_positions = inputs_embeds[:, -1]

        # embed positions
        embed_pos = self.embed_positions(decoder_positions)
        embed_pos = embed_pos.to(inputs_embeds.device)

        hidden_states = inputs_embeds + embed_pos
        hidden_states = self.layernorm_embedding(hidden_states)

        # decoder layers

        for decoder_layer in self.layers:
            hidden_states = decoder_layer(
                decoder_hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
            )

        hidden_states = self.layer_norm(hidden_states)
        return hidden_states


class MBartModel(nn.Module, SupportsQuant):
    _tied_weights_keys = [
        "encoder.embed_tokens.weight", "decoder.embed_tokens.weight"
    ]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config

        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.encoder = MBartEncoder(config,
                                    cache_config,
                                    quant_config=quant_config,
                                    prefix=f"{prefix}.encoder")
        self.decoder = MBartDecoder(config,
                                    cache_config,
                                    quant_config=quant_config,
                                    prefix=f"{prefix}.decoder")

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                encoder_input_ids: torch.Tensor,
                encoder_positions: torch.Tensor) -> torch.Tensor:
        r"""
        Args:
            input_ids
                Indices of *decoder* input sequence tokens in the vocabulary.
                Padding will be ignored by default should you
                provide it.
            positions
                Positions of *decoder* input sequence tokens.
            encoder_input_ids
                Indices of *encoder* input sequence tokens in the vocabulary.
            encoder_positions:
                Positions of *encoder* input sequence tokens.
        Returns:
            Model output torch.Tensor
        """

        encoder_hidden_states = None

        if encoder_input_ids.numel() > 0:
            # Run encoder attention if a non-zero number of encoder tokens
            # are provided as input
            encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
                                                 positions=encoder_positions)

        # decoder outputs consist of
        # (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            decoder_input_ids=input_ids,
            decoder_positions=positions,
            encoder_hidden_states=encoder_hidden_states)

        return decoder_outputs


class MBartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
    base_model_prefix = "model"

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "decoder.": "model.decoder.",
            "encoder.": "model.encoder.",
            "shared.": "model.shared."
        },
        orig_to_new_substr={
            "beta": "bias",
            "gamma": "weight",
            "LayerNorm": "layernorm",
        },
    )
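    # The mapper rewrites HF checkpoint names onto this module tree: for example,
    # a checkpoint key beginning with "encoder." is loaded under "model.encoder.",
    # and "LayerNorm" substrings are matched against the lower-case "layernorm"
    # attributes used here. (Illustrative mapping only; exact keys depend on the
    # checkpoint.)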

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        lora_config = vllm_config.lora_config
        assert config.tie_word_embeddings
        self.config = config
        self.model = MBartModel(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))

        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

        embed_scale = math.sqrt(
            config.d_model) if config.scale_embedding else 1.0

        self.lm_head = BartParallelLMHead(config.vocab_size,
                                          config.d_model,
                                          embed_scale=embed_scale)

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        *,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        return self.model(input_ids, positions, encoder_input_ids,
                          encoder_positions)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        model_params_dict = dict(self.named_parameters())
        loaded_params = set()
        remaining_weights = []
        shared_embedding_weight = None

        for name, loaded_weight in weights:
            if any(skip in name
                   for skip in ["cls.", "pooler.", "final_logits_bias"]):
                continue
            if any(embed_name in name for embed_name in [
                    'shared.weight', 'encoder.embed_tokens.weight',
                    'decoder.embed_tokens.weight'
            ]):
                if shared_embedding_weight is None:
                    shared_embedding_weight = loaded_weight
                continue
            is_stacked = False
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                vllm_name = name
                for src, dst in self.hf_to_vllm_mapper.orig_to_new_substr.items(
                ):
                    vllm_name = vllm_name.replace(src, dst)
                for src, dst in self.hf_to_vllm_mapper.orig_to_new_prefix.items(
                ):
                    if vllm_name.startswith(src):
                        vllm_name = dst + vllm_name[len(src):]
                        break
                vllm_name = vllm_name.replace(weight_name, param_name)
                if vllm_name in model_params_dict:
                    param = model_params_dict[vllm_name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight, shard_id)
                    loaded_params.add(vllm_name)
                    is_stacked = True
                    break
            if not is_stacked:
                remaining_weights.append((name, loaded_weight))
        loader = AutoWeightsLoader(self, skip_prefixes=["cls.", "pooler."])
        auto_loaded_params = loader.load_weights(remaining_weights,
                                                 mapper=self.hf_to_vllm_mapper)
        loaded_params.update(auto_loaded_params)
        if shared_embedding_weight is not None:
            lm_head_param = self.lm_head.weight
            weight_loader = getattr(lm_head_param, "weight_loader",
                                    default_weight_loader)
            weight_loader(lm_head_param, shared_embedding_weight)
            self.model.encoder.embed_tokens.weight = self.lm_head.weight
            self.model.decoder.embed_tokens.weight = self.lm_head.weight
            loaded_params.update({
                'model.encoder.embed_tokens.weight', 'lm_head.weight',
                'model.decoder.embed_tokens.weight'
            })
        return loaded_params

@@ -141,6 +141,7 @@ _TEXT_GENERATION_MODELS = {
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
    "MBartForConditionalGeneration": ("bart", "MBartForConditionalGeneration"),
}

_EMBEDDING_MODELS = {