mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 07:04:53 +08:00)

[New Model] mBART model (#22883)

Signed-off-by: 汪志鹏 <wangzhipeng628@gmail.com>

This commit is contained in:
parent 4dff91c93d
commit 829bbd7882
@@ -330,6 +330,7 @@ th {

| `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | ✅︎ |
| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | |
| `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | |
| `MBartForConditionalGeneration` | mBART | `facebook/mbart-large-en-ro`, `facebook/mbart-large-50`, etc. | | | |
| `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `zai-org/chatglm2-6b`, `zai-org/chatglm3-6b`, `ShieldLM-6B-chatglm3`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `CohereForCausalLM`, `Cohere2ForCausalLM` | Command-R | `CohereLabs/c4ai-command-r-v01`, `CohereLabs/c4ai-command-r7b-12-2024`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `DbrxForCausalLM` | DBRX | `databricks/dbrx-base`, `databricks/dbrx-instruct`, etc. | | ✅︎ | ✅︎ |
@@ -418,6 +419,9 @@ Some models are supported only via the [Transformers backend](#transformers). Th

!!! note
    Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096.

!!! note
    Some mBART models' config files do not have an `architectures` field defined. Therefore, you need to use `--hf-overrides '{"architectures": ["MBartForConditionalGeneration"]}'` to explicitly specify the use of the `MBartForConditionalGeneration` architecture.
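The same override can also be set from the Python API. Below is a minimal sketch that mirrors the example script added in this commit; the checkpoint choice and sampling settings are illustrative, not prescribed by the docs.

```python
from vllm import LLM, SamplingParams

# hf_overrides supplies the `architectures` entry that some mBART configs lack.
llm = LLM(
    model="facebook/mbart-large-en-ro",
    hf_overrides={"architectures": ["MBartForConditionalGeneration"]},
)

# Encoder/decoder models take explicit encoder and decoder prompts.
outputs = llm.generate(
    {"encoder_prompt": "How are you today?", "decoder_prompt": ""},
    SamplingParams(temperature=0, max_tokens=30),
)
print(outputs[0].outputs[0].text)
```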
### Pooling Models

See [this page](./pooling_models.md) for more information on how to use pooling models.
@@ -2,9 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrate prompting of text-to-text
encoder/decoder models, specifically BART and mBART.

This script is refactored to allow model selection via command-line arguments.
"""

import argparse
from typing import NamedTuple, Optional

from vllm import LLM, SamplingParams
from vllm.inputs import (
    ExplicitEncoderDecoderPrompt,

@@ -14,119 +19,175 @@ from vllm.inputs import (
)


class ModelRequestData(NamedTuple):
    """
    Holds the configuration for a specific model, including its
    HuggingFace ID and the prompts to use for the demo.
    """

    model_id: str
    encoder_prompts: list
    decoder_prompts: list
    hf_overrides: Optional[dict] = None


def get_bart_config() -> ModelRequestData:
    """
    Returns the configuration for facebook/bart-large-cnn.
    This uses the exact test cases from the original script.
    """
    encoder_prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "An encoder prompt",
    ]
    decoder_prompts = [
        "A decoder prompt",
        "Another decoder prompt",
    ]
    return ModelRequestData(
        model_id="facebook/bart-large-cnn",
        encoder_prompts=encoder_prompts,
        decoder_prompts=decoder_prompts,
    )


def get_mbart_config() -> ModelRequestData:
    """
    Returns the configuration for facebook/mbart-large-en-ro.
    This uses prompts suitable for an English-to-Romanian translation task.
    """
    encoder_prompts = [
        "The quick brown fox jumps over the lazy dog.",
        "How are you today?",
    ]
    decoder_prompts = ["", ""]
    hf_overrides = {"architectures": ["MBartForConditionalGeneration"]}
    return ModelRequestData(
        model_id="facebook/mbart-large-en-ro",
        encoder_prompts=encoder_prompts,
        decoder_prompts=decoder_prompts,
        hf_overrides=hf_overrides,
    )


MODEL_GETTERS = {
    "bart": get_bart_config,
    "mbart": get_mbart_config,
}


def create_all_prompt_types(
    encoder_prompts_raw: list,
    decoder_prompts_raw: list,
    tokenizer,
) -> list:
    """
    Generates a list of diverse prompt types for demonstration.
    This function is generic and uses the provided raw prompts
    to create various vLLM input objects.
    """
    text_prompt_raw = encoder_prompts_raw[0]
    text_prompt = TextPrompt(prompt=encoder_prompts_raw[1 % len(encoder_prompts_raw)])
    tokens_prompt = TokensPrompt(
        prompt_token_ids=tokenizer.encode(
            encoder_prompts_raw[2 % len(encoder_prompts_raw)]
        )
    )

    decoder_tokens_prompt = TokensPrompt(
        prompt_token_ids=tokenizer.encode(decoder_prompts_raw[0])
    )
    single_prompt_examples = [
        text_prompt_raw,
        text_prompt,
        tokens_prompt,
    ]
    explicit_pair_examples = [
        ExplicitEncoderDecoderPrompt(
            encoder_prompt=text_prompt_raw,
            decoder_prompt=decoder_tokens_prompt,
        ),
        ExplicitEncoderDecoderPrompt(
            encoder_prompt=text_prompt,
            decoder_prompt=decoder_prompts_raw[1 % len(decoder_prompts_raw)],
        ),
        ExplicitEncoderDecoderPrompt(
            encoder_prompt=tokens_prompt,
            decoder_prompt=text_prompt,
        ),
    ]
    zipped_prompt_list = zip_enc_dec_prompts(
        encoder_prompts_raw,
        decoder_prompts_raw,
    )
    return single_prompt_examples + explicit_pair_examples + zipped_prompt_list


def create_sampling_params() -> SamplingParams:
    """Create a sampling params object."""
    return SamplingParams(
        temperature=0,
        top_p=1.0,
        min_tokens=0,
        max_tokens=30,
    )


def print_outputs(outputs: list):
    """Formats and prints the generation outputs."""
    print("-" * 80)
    for i, output in enumerate(outputs):
        prompt = output.prompt
        encoder_prompt = output.encoder_prompt
        generated_text = output.outputs[0].text
        print(f"Output {i + 1}:")
        print(f"Encoder Prompt: {encoder_prompt!r}")
        print(f"Decoder Prompt: {prompt!r}")
        print(f"Generated Text: {generated_text!r}")
        print("-" * 80)


def main(args):
    """Main execution function."""
    model_key = args.model
    if model_key not in MODEL_GETTERS:
        raise ValueError(
            f"Unknown model: {model_key}. "
            f"Available models: {list(MODEL_GETTERS.keys())}"
        )
    config_getter = MODEL_GETTERS[model_key]
    model_config = config_getter()

    print(f"🚀 Running demo for model: {model_config.model_id}")

    llm = LLM(
        model=model_config.model_id,
        dtype="float",
        hf_overrides=model_config.hf_overrides,
    )

    tokenizer = llm.llm_engine.get_tokenizer_group()
    prompts = create_all_prompt_types(
        encoder_prompts_raw=model_config.encoder_prompts,
        decoder_prompts_raw=model_config.decoder_prompts,
        tokenizer=tokenizer,
    )
    sampling_params = create_sampling_params()

    outputs = llm.generate(prompts, sampling_params)

    print_outputs(outputs)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="A flexible demo for vLLM encoder-decoder models."
    )
    parser.add_argument(
        "--model",
        "-m",
        type=str,
        default="bart",
        choices=MODEL_GETTERS.keys(),
        help="The short name of the model to run.",
    )
    args = parser.parse_args()
    main(args)
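For readers following the refactor: `zip_enc_dec_prompts` (imported from `vllm.inputs` above) pairs the i-th encoder prompt with the i-th decoder prompt. A minimal sketch using the mBART demo prompts; the exact shape of the returned items is an assumption inferred from how the script consumes them, not stated in the diff.

```python
from vllm.inputs import zip_enc_dec_prompts

pairs = zip_enc_dec_prompts(
    ["The quick brown fox jumps over the lazy dog.", "How are you today?"],
    ["", ""],
)
# Assumed result: a list of explicit encoder/decoder prompts, roughly
# [{"encoder_prompt": "The quick brown fox ...", "decoder_prompt": ""},
#  {"encoder_prompt": "How are you today?", "decoder_prompt": ""}]
# which llm.generate(...) accepts directly, as in main() above.
```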
tests/models/language/generation/test_mbart.py (new file, 123 lines)

@@ -0,0 +1,123 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional

import pytest
from transformers import AutoModelForSeq2SeqLM

from vllm.sequence import SampleLogprobs

from ....conftest import DecoderPromptType, HfRunner, VllmRunner
from ...utils import check_logprobs_close


def vllm_to_hf_output(
    vllm_output: tuple[list[int], str, Optional[SampleLogprobs]],
    decoder_prompt_type: DecoderPromptType,
):
    """Sanitize vllm output to be comparable with hf output."""
    output_ids, output_str, out_logprobs = vllm_output
    hf_output_str = output_str + "</s>"
    return output_ids, hf_output_str, out_logprobs


def run_test(
    hf_runner: type[HfRunner],
    vllm_runner: type[VllmRunner],
    prompts: list[dict[str, str]],
    decoder_prompt_type: DecoderPromptType,
    model: str,
    *,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
) -> None:
    '''
    Test the vLLM mBART model by validating it against HuggingFace (HF).
    (Docstring content is omitted for brevity)
    '''
    vllm_prompts = prompts
    if decoder_prompt_type == DecoderPromptType.NONE:
        vllm_prompts = [{
            "encoder_prompt": p['encoder_prompt'],
            "decoder_prompt": ""
        } for p in prompts]

    vllm_kwargs = {
        "hf_overrides": {
            "architectures": ["MBartForConditionalGeneration"]
        }
    }

    with vllm_runner(model,
                     dtype=dtype,
                     tensor_parallel_size=tensor_parallel_size,
                     distributed_executor_backend=distributed_executor_backend,
                     enforce_eager=True,
                     **vllm_kwargs) as vllm_model:  # type: ignore
        vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
            vllm_prompts, max_tokens, num_logprobs)

    hf_kwargs = {
        "top_k": None,
        "num_beams": 1,
        "repetition_penalty": 1.0,
        "top_p": 1.0,
        "length_penalty": 1.0,
        "early_stopping": False,
        "no_repeat_ngram_size": None,
        "min_length": 0
    }

    with hf_runner(model, dtype=dtype,
                   auto_cls=AutoModelForSeq2SeqLM) as hf_model:
        hf_kwargs["decoder_start_token_id"] = (
            hf_model.tokenizer.lang_code_to_id["ro_RO"])

        hf_outputs = (
            hf_model.generate_encoder_decoder_greedy_logprobs_limit(
                prompts,  # HF runner still uses the original prompts
                max_tokens,
                num_logprobs,
                **hf_kwargs,
            ))

    hf_skip_tokens = 0

    check_logprobs_close(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=[
            vllm_to_hf_output(vllm_output, decoder_prompt_type)
            for vllm_output in vllm_outputs
        ],
        name_0="hf",
        name_1="vllm",
        num_outputs_0_skip_tokens=hf_skip_tokens,
    )


@pytest.mark.parametrize(
    "model",
    [pytest.param("facebook/mbart-large-en-ro")],
)
@pytest.mark.parametrize("dtype", ["float", "bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model,
                dtype, max_tokens, num_logprobs, decoder_prompt_type) -> None:
    run_test(
        hf_runner,
        vllm_runner,
        example_encoder_decoder_prompts[decoder_prompt_type],
        decoder_prompt_type,
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )
@@ -316,6 +316,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {

    # [Encoder-decoder]
    "BartModel": _HfExamplesInfo("facebook/bart-base"),
    "BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"),
    "MBartForConditionalGeneration": _HfExamplesInfo("facebook/mbart-large-en-ro",  # noqa: E501
                                                     hf_overrides={"architectures": ["MBartForConditionalGeneration"]}),  # noqa: E501
}

_EMBEDDING_EXAMPLE_MODELS = {
@@ -46,7 +46,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata

from vllm.sequence import IntermediateTensors

from .interfaces import SupportsQuant, SupportsV0Only
from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
                    maybe_prefix)

logger = logging.get_logger(__name__)
@@ -422,10 +423,7 @@ class BartEncoderLayer(nn.Module):

        if hidden_states.dtype == torch.float16 and (
                torch.isinf(hidden_states).any()
                or torch.isnan(hidden_states).any()):
            hidden_states = cast_overflow_tensors(hidden_states)

        return hidden_states
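This hunk swaps the hand-rolled fp16 clamping for the shared `cast_overflow_tensors` utility imported above. As a rough sketch of the behaviour being delegated, inferred from the clamping code this change removes (the helper's real signature and default offset are assumptions, not taken from this diff):

```python
import torch


def cast_overflow_tensors_sketch(tensors: torch.Tensor,
                                 offset: float = 1000.0) -> torch.Tensor:
    """Hypothetical stand-in for vLLM's cast_overflow_tensors helper."""
    # Clamp activations that overflowed to inf/nan back into fp16 range,
    # mirroring the torch.clamp(..., min=-clamp_value, max=clamp_value)
    # logic that this commit removes from BartEncoderLayer.
    if torch.isinf(tensors).any() or torch.isnan(tensors).any():
        clamp_value = torch.finfo(tensors.dtype).max - offset
        tensors = torch.clamp(tensors, min=-clamp_value, max=clamp_value)
    return tensors
```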
@@ -906,3 +904,439 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):

            })

        return loaded_params


class MBartEncoderLayer(BartEncoderLayer):

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        r"""
        Args:
            hidden_states
                torch.Tensor of *encoder* input embeddings.
        Returns:
            Encoder layer output torch.Tensor
        """
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(hidden_states=hidden_states)

        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        fc1_out, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(fc1_out)

        hidden_states, _ = self.fc2(hidden_states)

        hidden_states = residual + hidden_states

        if hidden_states.dtype == torch.float16 and (
                torch.isinf(hidden_states).any()
                or torch.isnan(hidden_states).any()):
            hidden_states = cast_overflow_tensors(hidden_states)

        return hidden_states


class MBartDecoderLayer(BartDecoderLayer):

    def forward(
        self,
        decoder_hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        residual = decoder_hidden_states
        hidden_states = self.self_attn_layer_norm(decoder_hidden_states)

        # Self Attention
        hidden_states = self.self_attn(hidden_states=hidden_states)

        hidden_states = residual + hidden_states

        # Cross-Attention Block

        residual = hidden_states
        hidden_states = self.encoder_attn_layer_norm(hidden_states)

        hidden_states = self.encoder_attn(
            decoder_hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
        )

        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        fc1_out, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(fc1_out)

        hidden_states, _ = self.fc2(hidden_states)

        hidden_states = residual + hidden_states

        return hidden_states


class MBartEncoder(nn.Module):
    """
    Transformer encoder consisting of *config.encoder_layers*
    self attention layers. Each layer is a [`BartEncoderLayer`].
    Args:
        config: BartConfig
        embed_tokens (nn.Embedding): output embedding
    """

    def __init__(self,
                 config: BartConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 lora_config: Optional[LoRAConfig] = None,
                 embed_tokens: Optional[nn.Embedding] = None,
                 prefix: str = ""):
        super().__init__()

        self.cache_config = cache_config
        self.quant_config = quant_config
        self.lora_config = lora_config
        embed_dim = config.d_model
        self.max_source_positions = config.max_position_embeddings
        embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
                                                    embed_dim,
                                                    embed_scale=embed_scale)

        if embed_tokens is not None:
            self.embed_tokens.weight = embed_tokens.weight

        self.embed_positions = BartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            embed_dim,
        )
        self.layers = nn.ModuleList([
            MBartEncoderLayer(config,
                              cache_config,
                              quant_config,
                              prefix=f"{prefix}.layers.{layer_idx}")
            for layer_idx in range(config.encoder_layers)
        ])

        self.layernorm_embedding = nn.LayerNorm(embed_dim)
        self.layer_norm = nn.LayerNorm(config.d_model)  # changed from BART: final encoder LayerNorm

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            input_ids
                Indices of *encoder* input sequence tokens in the vocabulary.
                Padding will be ignored by default should you
                provide it.
            positions
                Positions of *encoder* input sequence tokens.
        Returns:
            Encoder output torch.Tensor
        """
        # retrieve input_ids and inputs_embeds
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        embed_pos = self.embed_positions(positions)
        embed_pos = embed_pos.to(inputs_embeds.device)

        hidden_states = inputs_embeds + embed_pos
        hidden_states = self.layernorm_embedding(hidden_states)

        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states=hidden_states)

        hidden_states = self.layer_norm(hidden_states)
        return hidden_states


class MBartDecoder(nn.Module):
    """
    Transformer decoder consisting of *config.decoder_layers* layers.
    Each layer is a [`BartDecoderLayer`]
    Args:
        config: BartConfig
        embed_tokens (nn.Embedding): output embedding
    """

    def __init__(
        self,
        config: BartConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        embed_tokens: Optional[nn.Embedding] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.cache_config = cache_config
        self.quant_config = quant_config
        self.lora_config = lora_config
        self.max_target_positions = config.max_position_embeddings
        embed_scale = math.sqrt(
            config.d_model) if config.scale_embedding else 1.0

        self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
                                                    config.d_model,
                                                    embed_scale=embed_scale)

        if embed_tokens is not None:
            self.embed_tokens.weight = embed_tokens.weight

        self.embed_positions = BartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
        )

        self.layers = nn.ModuleList(
            [MBartDecoderLayer(config, cache_config, quant_config,
                               prefix=f"{prefix}.layers.{layer_idx}") \
             for layer_idx in range(config.decoder_layers)])

        self.layernorm_embedding = nn.LayerNorm(config.d_model)
        self.layer_norm = nn.LayerNorm(config.d_model)

    def forward(
        self,
        decoder_input_ids: torch.Tensor,
        decoder_positions: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            decoder_input_ids
                Indices of *decoder* input sequence tokens in the vocabulary.
                Padding will be ignored by default should you
                provide it.
            decoder_positions
                Positions of *decoder* input sequence tokens.
            encoder_hidden_states:
                Tensor of encoder output embeddings
        Returns:
            Decoder output torch.Tensor
        """
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(decoder_input_ids)
        else:
            decoder_positions = inputs_embeds[:, -1]

        # embed positions
        embed_pos = self.embed_positions(decoder_positions)
        embed_pos = embed_pos.to(inputs_embeds.device)

        hidden_states = inputs_embeds + embed_pos
        hidden_states = self.layernorm_embedding(hidden_states)

        # decoder layers

        for decoder_layer in self.layers:
            hidden_states = decoder_layer(
                decoder_hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
            )

        hidden_states = self.layer_norm(hidden_states)
        return hidden_states


class MBartModel(nn.Module, SupportsQuant):
    _tied_weights_keys = [
        "encoder.embed_tokens.weight", "decoder.embed_tokens.weight"
    ]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config

        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.encoder = MBartEncoder(config,
                                    cache_config,
                                    quant_config=quant_config,
                                    prefix=f"{prefix}.encoder")
        self.decoder = MBartDecoder(config,
                                    cache_config,
                                    quant_config=quant_config,
                                    prefix=f"{prefix}.decoder")

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                encoder_input_ids: torch.Tensor,
                encoder_positions: torch.Tensor) -> torch.Tensor:
        r"""
        Args:
            input_ids
                Indices of *decoder* input sequence tokens in the vocabulary.
                Padding will be ignored by default should you
                provide it.
            positions
                Positions of *decoder* input sequence tokens.
            encoder_input_ids
                Indices of *encoder* input sequence tokens in the vocabulary.
            encoder_positions:
                Positions of *encoder* input sequence tokens.
        Returns:
            Model output torch.Tensor
        """

        encoder_hidden_states = None

        if encoder_input_ids.numel() > 0:
            # Run encoder attention if a non-zero number of encoder tokens
            # are provided as input
            encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
                                                 positions=encoder_positions)

        # decoder outputs consists of
        # (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            decoder_input_ids=input_ids,
            decoder_positions=positions,
            encoder_hidden_states=encoder_hidden_states)

        return decoder_outputs


class MBartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
    base_model_prefix = "model"

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "decoder.": "model.decoder.",
            "encoder.": "model.encoder.",
            "shared.": "model.shared."
        },
        orig_to_new_substr={
            "beta": "bias",
            "gamma": "weight",
            "LayerNorm": "layernorm",
        },
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        lora_config = vllm_config.lora_config
        assert config.tie_word_embeddings
        self.config = config
        self.model = MBartModel(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))

        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

        embed_scale = math.sqrt(
            config.d_model) if config.scale_embedding else 1.0

        self.lm_head = BartParallelLMHead(config.vocab_size,
                                          config.d_model,
                                          embed_scale=embed_scale)

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        *,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        return self.model(input_ids, positions, encoder_input_ids,
                          encoder_positions)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        model_params_dict = dict(self.named_parameters())
        loaded_params = set()
        remaining_weights = []
        shared_embedding_weight = None

        for name, loaded_weight in weights:
            if any(skip in name
                   for skip in ["cls.", "pooler.", "final_logits_bias"]):
                continue
            if any(embed_name in name for embed_name in [
                    'shared.weight', 'encoder.embed_tokens.weight',
                    'decoder.embed_tokens.weight'
            ]):
                if shared_embedding_weight is None:
                    shared_embedding_weight = loaded_weight
                continue
            is_stacked = False
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                vllm_name = name
                for src, dst in self.hf_to_vllm_mapper.orig_to_new_substr.items(
                ):
                    vllm_name = vllm_name.replace(src, dst)
                for src, dst in self.hf_to_vllm_mapper.orig_to_new_prefix.items(
                ):
                    if vllm_name.startswith(src):
                        vllm_name = dst + vllm_name[len(src):]
                        break
                vllm_name = vllm_name.replace(weight_name, param_name)
                if vllm_name in model_params_dict:
                    param = model_params_dict[vllm_name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight, shard_id)
                    loaded_params.add(vllm_name)
                    is_stacked = True
                break
            if not is_stacked:
                remaining_weights.append((name, loaded_weight))
        loader = AutoWeightsLoader(self, skip_prefixes=["cls.", "pooler."])
        auto_loaded_params = loader.load_weights(remaining_weights,
                                                 mapper=self.hf_to_vllm_mapper)
        loaded_params.update(auto_loaded_params)
        if shared_embedding_weight is not None:
            lm_head_param = self.lm_head.weight
            weight_loader = getattr(lm_head_param, "weight_loader",
                                    default_weight_loader)
            weight_loader(lm_head_param, shared_embedding_weight)
            self.model.encoder.embed_tokens.weight = self.lm_head.weight
            self.model.decoder.embed_tokens.weight = self.lm_head.weight
            loaded_params.update({
                'model.encoder.embed_tokens.weight', 'lm_head.weight',
                'model.decoder.embed_tokens.weight'
            })
        return loaded_params
@@ -141,6 +141,7 @@ _TEXT_GENERATION_MODELS = {

    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
    "MBartForConditionalGeneration": ("bart", "MBartForConditionalGeneration"),
}

_EMBEDDING_MODELS = {