[Model] Add support for the multi-modal Llama 3.2 model (#8811)

Co-authored-by: simon-mo <xmo@berkeley.edu>
Co-authored-by: Chang Su <chang.s.su@oracle.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
Authored by Chen Zhang on 2024-09-25 13:29:32 -07:00; committed by GitHub
commit 770ec6024f (parent 4f1ba0844b)
24 changed files with 1646 additions and 44 deletions

View File

@ -254,6 +254,11 @@ Multimodal Language Models
- Image\ :sup:`+`
- :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
-
* - :code:`MllamaForConditionalGeneration`
- Llama 3.2
- Image
- :code:`meta-llama/Llama-3.2-90B-Vision-Instruct`, :code:`meta-llama/Llama-3.2-11B-Vision`, etc.
-
* - :code:`PaliGemmaForConditionalGeneration`
- PaliGemma
- Image\ :sup:`E`

View File

@ -242,6 +242,29 @@ def run_qwen2_vl(question, modality):
return llm, prompt, stop_token_ids
# Llama 3.2
def run_mllama(question, modality):
assert modality == "image"
model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"
# Note: The default setting of max_num_seqs (256) and
# max_model_len (131072) for this model may cause OOM.
# You may lower either to run this example on lower-end GPUs.
# The configuration below has been confirmed to launch on a
# single H100 GPU.
llm = LLM(
model=model_name,
max_num_seqs=16,
enforce_eager=True,
)
prompt = f"<|image|><|begin_of_text|>{question}"
stop_token_ids = None
return llm, prompt, stop_token_ids
model_example_map = {
"llava": run_llava,
"llava-next": run_llava_next,
@ -256,6 +279,7 @@ model_example_map = {
"internvl_chat": run_internvl,
"qwen_vl": run_qwen_vl,
"qwen2_vl": run_qwen2_vl,
"mllama": run_mllama,
}
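A minimal sketch of how the new helper might be exercised outside the example script's CLI harness; the image path and sampling settings below are illustrative and not part of this commit:

from PIL import Image
from vllm import SamplingParams

llm, prompt, stop_token_ids = run_mllama("What is in this image?", "image")
image = Image.open("example.jpg").convert("RGB")  # hypothetical local image
outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(temperature=0.0, max_tokens=64,
                   stop_token_ids=stop_token_ids),
)
print(outputs[0].outputs[0].text)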

View File

@ -38,7 +38,7 @@ chat_completion_from_url = client.chat.completions.create(
"content": [
{
"type": "text",
"text": "Whats in this image?"
"text": "What's in this image?"
},
{
"type": "image_url",
@ -75,7 +75,7 @@ chat_completion_from_base64 = client.chat.completions.create(
"content": [
{
"type": "text",
"text": "Whats in this image?"
"text": "What's in this image?"
},
{
"type": "image_url",

View File

@ -4,7 +4,7 @@ numpy < 2.0.0
requests
tqdm
py-cpuinfo
transformers >= 4.43.2 # Required for Chameleon and Llama 3.1 hotfix.
transformers >= 4.45.0 # Required for Llama 3.2.
tokenizers >= 0.19.1 # Required for Llama 3.
protobuf # Required by LlamaTokenizer.
fastapi < 0.113.0; python_version < '3.9'

View File

@ -0,0 +1,283 @@
from typing import List, Optional, Tuple, Type, overload
import pytest
from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
BatchEncoding)
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
_ImageAssets)
from ....utils import multi_gpu_test
from ...utils import check_logprobs_close
_LIMIT_IMAGE_PER_PROMPT = 1
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
"<|image|><|begin_of_text|>The meaning of the image is",
"cherry_blossom":
"<|image|><|begin_of_text|>The city is",
})
text_only_prompts = [
"The color of the sky is blue but sometimes it can also be",
]
models = [
"meta-llama/Llama-3.2-11B-Vision-Instruct",
]
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
Optional[SampleLogprobs]],
model: str):
"""Sanitize vllm output to be comparable with hf output."""
output_ids, output_str, out_logprobs = vllm_output
config = AutoConfig.from_pretrained(model)
image_token_id = config.image_token_index
tokenizer = AutoTokenizer.from_pretrained(model)
eos_token_id = tokenizer.eos_token_id
hf_output_ids = [
token_id for idx, token_id in enumerate(output_ids)
if token_id != image_token_id or output_ids[idx - 1] != image_token_id
]
assert output_str[0] == " "
hf_output_str = output_str[1:]
if hf_output_ids[-1] == eos_token_id:
hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)
return hf_output_ids, hf_output_str, out_logprobs
@overload
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model: str,
*,
size_factors: List[float],
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
...
@overload
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model: str,
*,
sizes: List[Tuple[int, int]],
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
...
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model: str,
*,
size_factors: Optional[List[float]] = None,
sizes: Optional[List[Tuple[int, int]]] = None,
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
images = [asset.pil_image for asset in image_assets]
if size_factors is not None:
inputs_per_image = [(
[prompt for _ in size_factors],
[rescale_image_size(image, factor) for factor in size_factors],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
elif sizes is not None:
inputs_per_image = [(
[
prompt if size is not None else text_only_prompts[0]
for size in sizes
],
[
image.resize(size) if size is not None else None
for size in sizes
],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
if len(sizes) == 0:
inputs_per_image.append(
(text_only_prompts, [None] * len(text_only_prompts)))
else:
raise ValueError("You must provide either `size_factors` or `sizes`")
_run_test(hf_runner,
vllm_runner,
inputs_per_image,
model,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend)
def _run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
inputs: List[Tuple[List[str], PromptImageInput]],
model: str,
*,
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test are from IMAGE_ASSETS.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding MultiModalConfig as input.
Note that the text input is also adjusted to abide by the vLLM contract.
The text output is sanitized to be able to compare with hf.
"""
# NOTE: Order matters: run vLLM first, then HF.
# vLLM needs a fresh process in which CUDA has not been initialized;
# if HF runs first, CUDA gets initialized and that breaks the default
# fork-based multiprocessing backend.
# max_model_len should be greater than image_feature_size
with vllm_runner(model,
dtype=dtype,
max_num_seqs=16,
max_model_len=4096,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True,
limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT
}) as vllm_model:
vllm_outputs_per_image = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs
]
def process(hf_inputs: BatchEncoding):
return hf_inputs
from transformers import AutoConfig
from transformers.models.mllama import MllamaConfig as MllamaConfigHf
# Use transformers' MllamaConfig for hf_runner
# and vLLM's MllamaConfig for vllm_runner.
AutoConfig.register("mllama", MllamaConfigHf, exist_ok=True)
with hf_runner(model,
dtype=dtype,
postprocess_inputs=process,
auto_cls=AutoModelForVision2Seq) as hf_model:
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs
]
from vllm.transformers_utils.configs.mllama import MllamaConfig
AutoConfig.register("mllama", MllamaConfig, exist_ok=True)
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
vllm_outputs_per_image):
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=[
vllm_to_hf_output(vllm_output, model)
for vllm_output in vllm_outputs
],
name_0="hf",
name_1="vllm",
)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"sizes",
[
# Text only
[],
# Single-size
[(512, 512)],
# Single-size, batched
[(512, 512), (512, 512), (512, 512)],
# Multi-size, batched
[(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
(1024, 1024), (512, 1536), (512, 2028)],
# Multi-size, batched, including text only
[(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
(1024, 1024), (512, 1536), (512, 2028), None],
# mllama has 8 possible aspect ratios, carefully set the sizes
# to cover all of them
],
)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, sizes, dtype,
max_tokens, num_logprobs) -> None:
run_test(
hf_runner,
vllm_runner,
image_assets,
model,
sizes=sizes,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"sizes",
[
[(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
(1024, 1024), (512, 1536), (512, 2028), None],
],
)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models_distributed(hf_runner, vllm_runner, image_assets, model, sizes,
dtype, max_tokens, num_logprobs) -> None:
run_test(
hf_runner,
vllm_runner,
image_assets,
model,
sizes=sizes,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=2,
)

View File

@ -576,7 +576,9 @@ class ModelConfig:
@property
def is_encoder_decoder_model(self) -> bool:
"""Extract the HF encoder/decoder model flag."""
return getattr(self.hf_config, "is_encoder_decoder", False)
return getattr(self.hf_config, "is_encoder_decoder", False) or (
(hasattr(self.hf_config, "text_config") and getattr(
self.hf_config.text_config, "is_encoder_decoder", False)))
@property
def is_embedding_model(self) -> bool:

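For mllama the encoder-decoder flag lives on the nested text config rather than on the top-level config (see the MllamaConfig override later in this diff), which is why the property now also inspects hf_config.text_config. A standalone sketch of the same lookup, with an assumed helper name:

def _is_encoder_decoder(hf_config) -> bool:
    # Top-level flag (e.g. BART) ...
    if getattr(hf_config, "is_encoder_decoder", False):
        return True
    # ... otherwise fall back to the nested text config (e.g. mllama).
    text_config = getattr(hf_config, "text_config", None)
    return bool(getattr(text_config, "is_encoder_decoder", False))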
View File

@ -1734,7 +1734,11 @@ class LLMEngine:
def _validate_model_inputs(self, inputs: Union[LLMInputs,
EncoderDecoderLLMInputs]):
if self.is_encoder_decoder_model():
if self.model_config.is_multimodal_model:
# For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length
prompt_ids = inputs.get("prompt_token_ids")
elif self.is_encoder_decoder_model():
prompt_ids = inputs.get("encoder_prompt_token_ids")
else:
prompt_ids = inputs.get("prompt_token_ids")

View File

@ -159,6 +159,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
hf_config.image_token_index)
if model_type in ("chameleon", "internvl_chat"):
return "<image>"
if model_type == "mllama":
return "<|image|>"
if model_type == "qwen2_vl":
return "<|vision_start|><|image_pad|><|vision_end|>"
@ -358,6 +360,7 @@ _TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
_AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
MODEL_KEEP_MULTI_MODAL_CONTENT = {'mllama'}
def _parse_chat_message_content_parts(
@ -368,7 +371,11 @@ def _parse_chat_message_content_parts(
texts: List[str] = []
mm_parser = mm_tracker.create_parser()
keep_multimodal_content = \
mm_tracker._model_config.hf_config.model_type in \
MODEL_KEEP_MULTI_MODAL_CONTENT
has_image = False
for part in parts:
part_type = part["type"]
if part_type == "text":
@ -383,6 +390,7 @@ def _parse_chat_message_content_parts(
"will be ignored.")
mm_parser.parse_image(image_url["url"])
has_image = True
elif part_type == "audio_url":
audio_url = _AudioParser(part)["audio_url"]
@ -394,12 +402,20 @@ def _parse_chat_message_content_parts(
raise NotImplementedError(f"Unknown part type: {part_type}")
text_prompt = "\n".join(texts)
mm_placeholder_counts = mm_parser.mm_placeholder_counts()
if mm_placeholder_counts:
text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts,
text_prompt)
if keep_multimodal_content:
text_prompt = "\n".join(texts)
role_content = [{'type': 'text', 'text': text_prompt}]
return [ConversationMessage(role=role, content=text_prompt)]
if has_image:
role_content = [{'type': 'image'}] + role_content
return [ConversationMessage(role=role,
content=role_content)] # type: ignore
else:
mm_placeholder_counts = mm_parser.mm_placeholder_counts()
if mm_placeholder_counts:
text_prompt = _get_full_multimodal_text_prompt(
mm_placeholder_counts, text_prompt)
return [ConversationMessage(role=role, content=text_prompt)]
# No need to validate using Pydantic again
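For reference, when an image part is present the mllama branch above returns a message whose content keeps the structured parts, so the model's own chat template can place the image token. Roughly (contents illustrative):

message = {
    "role": "user",
    "content": [
        {"type": "image"},
        {"type": "text", "text": "What's in this image?"},
    ],
}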

View File

@ -309,6 +309,8 @@ class OpenAIServingChat(OpenAIServing):
async for res in result_generator:
if res.prompt_token_ids is not None:
num_prompt_tokens = len(res.prompt_token_ids)
if res.encoder_prompt_token_ids is not None:
num_prompt_tokens += len(res.encoder_prompt_token_ids)
# We need to do it here, because if there are exceptions in
# the result_generator, it needs to be sent as the FIRST

View File

@ -139,6 +139,12 @@ class EncoderDecoderLLMInputs(LLMInputs):
available.
"""
encoder_multi_modal_data: NotRequired[Optional["MultiModalDataDict"]]
"""
Optional multi-modal data to pass to the encoder model,
if the model supports it.
"""
_T1 = TypeVar("_T1",
bound=SingletonPromptInputs,

View File

@ -128,6 +128,7 @@ class InputPreprocessor:
def _prepare_decoder_input_ids_for_generation(
self,
decoder_input_ids: Optional[List[int]],
force_bos: bool = True,
) -> List[int]:
"""
Prepares `decoder_input_ids` for generation with encoder-decoder models.
@ -157,8 +158,8 @@ class InputPreprocessor:
# use decoder_start_token_id as decoder_input_ids
decoder_input_ids = self._get_default_enc_dec_decoder_prompt()
if (len(decoder_input_ids) == 0
or decoder_input_ids[0] != decoder_start_token_id):
if force_bos and (len(decoder_input_ids) == 0
or decoder_input_ids[0] != decoder_start_token_id):
decoder_input_ids = [decoder_start_token_id] + decoder_input_ids
return decoder_input_ids
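A worked illustration of the rule above (helper name and token ids assumed, not vLLM code): with force_bos enabled the start token is prepended when missing, while mllama-style prompts that already begin with <|image|><|begin_of_text|> disable force_bos further down and pass through unchanged.

def prepend_start_token(decoder_ids, start_id, force_bos=True):
    # Mirrors the branch above: only prepend when forcing BOS and the
    # sequence does not already begin with the start token.
    if force_bos and (len(decoder_ids) == 0 or decoder_ids[0] != start_id):
        return [start_id] + decoder_ids
    return decoder_ids

assert prepend_start_token([], 1) == [1]
assert prepend_start_token([5, 6], 1) == [1, 5, 6]
assert prepend_start_token([9, 1, 7], 1, force_bos=False) == [9, 1, 7]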
@ -295,18 +296,25 @@ class InputPreprocessor:
encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps
decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps
if encoder_mm_data is not None or decoder_mm_data is not None:
raise ValueError("Multi-modal encoder-decoder models are "
"not supported yet")
if decoder_mm_data is not None:
raise ValueError(
"Multi-modality decoder inputs of encoder-decoder models are "
"not supported yet")
decoder_prompt_ids = (
self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids))
# For multi-modal models (e.g. mllama), the text input can already be
# "<|image|><|begin_of_text|>hello world", so we should not prepend
# another <|begin_of_text|> to it.
decoder_prompt_ids = (self._prepare_decoder_input_ids_for_generation(
decoder_prompt_ids,
force_bos=(encoder_mm_data is None and decoder_mm_data is None)))
return EncoderDecoderLLMInputs(
prompt_token_ids=decoder_prompt_ids,
prompt=decoder_prompt,
multi_modal_data=decoder_mm_data,
encoder_prompt_token_ids=encoder_prompt_ids,
encoder_prompt=encoder_prompt,
encoder_multi_modal_data=encoder_mm_data,
)
def _process_encoder_decoder_prompt(

View File

@ -112,6 +112,8 @@ class InputRegistry:
def __init__(self) -> None:
self._dummy_factories_by_model_type: Dict[Type[nn.Module],
DummyDataFactory] = {}
self._dummy_encoder_factories_by_model_type: Dict[
Type[nn.Module], DummyDataFactory] = {}
self._input_processors_by_model_type: Dict[Type[nn.Module],
InputProcessor] = {}
@ -162,11 +164,44 @@ class InputRegistry:
return self._dummy_factories_by_model_type \
.get(model_cls, self._default_dummy_data_factory)
def register_dummy_encoder_data(self, factory: DummyDataFactory):
"""
Register a dummy encoder data factory to a model class
This is similar to :meth:`~register_dummy_data`, but for encoder input.
"""
def wrapper(model_cls: N) -> N:
if model_cls in self._dummy_encoder_factories_by_model_type:
logger.warning(
"Model class %s already has dummy encoder data "
"registered to %s. It is overwritten by the new one.",
model_cls, self)
self._dummy_encoder_factories_by_model_type[model_cls] = factory
return model_cls
return wrapper
def _get_dummy_encoder_data_factory(self, model_cls: Type[nn.Module]):
if model_cls in self._dummy_encoder_factories_by_model_type:
dummy_factory = self._dummy_encoder_factories_by_model_type[
model_cls]
else:
logger.warning(
"No dummy encoder data factory registered to %s. "
"Using the dummy data factory for the model instead.",
model_cls)
dummy_factory = self._get_dummy_data_factory(model_cls)
return dummy_factory
def dummy_data_for_profiling(
self,
model_config: "ModelConfig",
seq_len: int,
mm_registry: "MultiModalRegistry",
is_encoder_data: bool = False,
) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]:
"""
Create dummy data for profiling the memory usage of a model.
@ -184,8 +219,10 @@ class InputRegistry:
from vllm.model_executor.model_loader import get_model_architecture
model_cls, _ = get_model_architecture(model_config)
dummy_factory = self._get_dummy_data_factory(model_cls)
if is_encoder_data:
dummy_factory = self._get_dummy_encoder_data_factory(model_cls)
else:
dummy_factory = self._get_dummy_data_factory(model_cls)
mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
mm_processor_kwargs = get_allowed_kwarg_only_overrides(
dummy_factory, overrides=model_config.mm_processor_kwargs)
@ -196,10 +233,15 @@ class InputRegistry:
# Having more tokens is over-conservative but otherwise fine
num_tokens = seq_data.prompt_token_ids
assert len(num_tokens) >= seq_len, (
f"Expected at least {seq_len} dummy tokens for profiling, "
f"but found {len(num_tokens)} tokens instead.")
if len(num_tokens) < seq_len:
if is_encoder_data:
logger.warning(
"Expected at least %d dummy encoder tokens for profiling, "
"but found %d tokens instead.", seq_len, len(num_tokens))
else:
raise AssertionError(
f"Expected at least {seq_len} dummy tokens for profiling, "
f"but found {len(num_tokens)} tokens instead.")
if mm_data is not None:
for k, v in mm_data.items():
num_items = len(v) if isinstance(v, list) else 1
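The model file that consumes this hook is suppressed later in the diff, but the new registration is presumably applied as a class decorator alongside the existing ones; an illustrative shape with hypothetical factory names:

from torch import nn
from vllm.inputs import INPUT_REGISTRY

# Hypothetical factory names, for illustration only; the real ones live in
# the suppressed mllama model file.
@INPUT_REGISTRY.register_dummy_data(dummy_decoder_data_for_mllama)
@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_mllama)
@INPUT_REGISTRY.register_input_processor(input_processor_for_mllama)
class MllamaForConditionalGeneration(nn.Module):
    ...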

View File

@ -101,6 +101,8 @@ _MULTIMODAL_MODELS = {
"Qwen2VLForConditionalGeneration": ("qwen2_vl",
"Qwen2VLForConditionalGeneration"),
"UltravoxModel": ("ultravox", "UltravoxModel"),
"MllamaForConditionalGeneration": ("mllama",
"MllamaForConditionalGeneration"),
}
_CONDITIONAL_GENERATION_MODELS = {
"BartModel": ("bart", "BartForConditionalGeneration"),

File diff suppressed because it is too large

View File

@ -54,6 +54,12 @@ class MultiModalInputs(_MultiModalInputsBase):
if isinstance(nested_tensors, torch.Tensor):
return nested_tensors
if isinstance(nested_tensors, np.ndarray):
return torch.from_numpy(nested_tensors)
if isinstance(nested_tensors, (int, float)):
return torch.tensor(nested_tensors)
stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors]
if not is_list_of(stacked, torch.Tensor, check="all"):
# Only tensors (not lists) can be stacked.

View File

@ -2,6 +2,7 @@ from functools import lru_cache
import torch
from PIL import Image
from transformers.image_processing_base import BatchFeature
from vllm.config import ModelConfig
from vllm.inputs.registry import InputContext
@ -39,6 +40,10 @@ class ImagePlugin(MultiModalPlugin):
) -> MultiModalInputs:
model_config = ctx.model_config
# Processed by input processor
if isinstance(data, BatchFeature):
return MultiModalInputs(data.data)
# PIL image
if isinstance(data, Image.Image) or is_list_of(data, Image.Image):
image_processor = self._get_hf_image_processor(model_config)
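A quick illustration of the new early-return path (shapes assumed): when the data arriving at the plugin is already a BatchFeature produced by an input processor, its tensors are forwarded as-is instead of being run through the HF image processor again.

import torch
from transformers.image_processing_base import BatchFeature
from vllm.multimodal import MultiModalInputs

feature = BatchFeature({"pixel_values": torch.zeros(1, 3, 448, 448)})
multimodal_inputs = MultiModalInputs(feature.data)  # what the branch above returns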

View File

@ -13,6 +13,7 @@ from typing import Set, Tuple, Union, cast
import msgspec
import torch
from vllm.inputs import EncoderDecoderLLMInputs, LLMInputs
from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
@ -21,7 +22,6 @@ from vllm.sampling_params import SamplingParams
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
if TYPE_CHECKING:
from vllm.inputs import LLMInputs
from vllm.multimodal.base import MultiModalDataDict
VLLM_TOKEN_ID_ARRAY_TYPE = "l"
@ -471,7 +471,15 @@ class Sequence:
@property
def multi_modal_data(self) -> "MultiModalDataDict":
return self.inputs.get("multi_modal_data") or {}
if self.inputs.get("multi_modal_data") and self.inputs.get(
"encoder_multi_modal_data"):
raise ValueError(
"Multi-modal data in both encoder and decoder is not supported."
)
inputs = self.inputs
return self.inputs.get("multi_modal_data") or (cast(
EncoderDecoderLLMInputs,
inputs).get("encoder_multi_modal_data")) or {}
@property
def lora_int_id(self) -> int:

View File

@ -22,9 +22,10 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
EAGLEConfig, ExaoneConfig,
GraniteConfig, InternVLChatConfig,
JAISConfig, MedusaConfig,
MLPSpeculatorConfig, MPTConfig,
NemotronConfig, RWConfig,
SolarConfig, UltravoxConfig)
MllamaConfig, MLPSpeculatorConfig,
MPTConfig, NemotronConfig,
RWConfig, SolarConfig,
UltravoxConfig)
# yapf: enable
from vllm.transformers_utils.utils import check_gguf_file
@ -37,6 +38,10 @@ MISTRAL_CONFIG_NAME = "params.json"
logger = init_logger(__name__)
_CONFIG_REGISTRY_OVERRIDE_HF: Dict[str, Type[PretrainedConfig]] = {
"mllama": MllamaConfig
}
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
"chatglm": ChatGLMConfig,
"dbrx": DbrxConfig,
@ -55,11 +60,15 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
# Granite can be removed from here once we have upgraded to
# transformers 4.45+
"granite": GraniteConfig,
**_CONFIG_REGISTRY_OVERRIDE_HF
}
for name, cls in _CONFIG_REGISTRY.items():
with contextlib.suppress(ValueError):
AutoConfig.register(name, cls)
if name in _CONFIG_REGISTRY_OVERRIDE_HF:
AutoConfig.register(name, cls, exist_ok=True)
else:
AutoConfig.register(name, cls)
class ConfigFormat(str, enum.Enum):

View File

@ -10,6 +10,7 @@ from vllm.transformers_utils.configs.granite import GraniteConfig
from vllm.transformers_utils.configs.internvl import InternVLChatConfig
from vllm.transformers_utils.configs.jais import JAISConfig
from vllm.transformers_utils.configs.medusa import MedusaConfig
from vllm.transformers_utils.configs.mllama import MllamaConfig
from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
from vllm.transformers_utils.configs.mpt import MPTConfig
from vllm.transformers_utils.configs.nemotron import NemotronConfig
@ -26,6 +27,7 @@ __all__ = [
"MedusaConfig",
"EAGLEConfig",
"ExaoneConfig",
"MllamaConfig",
"MLPSpeculatorConfig",
"NemotronConfig",
"SolarConfig",

View File

@ -0,0 +1,28 @@
from transformers.models.mllama import configuration_mllama as mllama_hf_config
class MllamaTextConfig(mllama_hf_config.MllamaTextConfig):
'''
Use this class to override is_encoder_decoder:
- transformers regards mllama as is_encoder_decoder=False
- vllm needs is_encoder_decoder=True to enable cross-attention
'''
def __init__(
self,
**kwargs,
):
super().__init__(**kwargs)
self.is_encoder_decoder = True
class MllamaConfig(mllama_hf_config.MllamaConfig):
def __init__(
self,
text_config=None,
**kwargs,
):
if isinstance(text_config, dict):
text_config = MllamaTextConfig(**text_config)
super().__init__(text_config=text_config, **kwargs)
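An illustrative check of the override's effect, assuming the AutoConfig registration added in vllm/transformers_utils/config.py has already run in the current process (the gated model id may require accepting its license):

from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct")
# With vLLM's MllamaConfig registered (exist_ok=True replaces the stock
# transformers class), the nested text config reports encoder-decoder,
# which is what ModelConfig.is_encoder_decoder_model checks earlier in
# this diff.
assert cfg.text_config.is_encoder_decoder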

View File

@ -111,7 +111,6 @@ def get_tokenizer(
'encoding and decoding.',
FutureWarning,
stacklevel=2)
if tokenizer_mode == "mistral":
tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name),
revision=revision)

View File

@ -18,7 +18,8 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs,
MultiModalRegistry)
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput,
SequenceGroupMetadata)
@ -52,6 +53,7 @@ class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata):
"virtual_engine": self.virtual_engine,
"request_ids_to_seq_ids": self.request_ids_to_seq_ids,
"finished_requests_ids": self.finished_requests_ids,
"multi_modal_kwargs": self.multi_modal_kwargs,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,
@ -194,6 +196,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
"finished_requests_ids": model_input.finished_requests_ids,
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
} if self.has_seqlen_agnostic else {}
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
@ -202,6 +206,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors,
**MultiModalInputs.as_kwargs(multi_modal_kwargs,
device=self.device),
**seqlen_agnostic_kwargs)
logits = self.model.compute_logits(hidden_or_intermediate_states,
@ -288,8 +294,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
self.model_config)
if max_mm_tokens > 0:
raise NotImplementedError(
"Multi-modal encoder-decoder models are not supported yet")
logger.info("Starting profile run for multi-modal models.")
batch_size = 0
for group_id in range(max_num_seqs):
@ -297,24 +302,39 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
(group_id < max_num_batched_tokens % max_num_seqs))
batch_size += seq_len
seq_data, _ = self.input_registry \
.dummy_data_for_profiling(self.model_config,
decoder_seq_data, decoder_dummy_multi_modal_data \
= self.input_registry.dummy_data_for_profiling(
self.model_config,
seq_len,
self.mm_registry)
self.mm_registry,
is_encoder_data=False)
encoder_seq_data, encoder_dummy_multi_modal_data \
= self.input_registry.dummy_data_for_profiling(
self.model_config,
seq_len,
self.mm_registry,
is_encoder_data=True)
# Having more tokens is over-conservative but otherwise fine
assert len(seq_data.prompt_token_ids) >= seq_len, (
assert len(decoder_seq_data.prompt_token_ids) >= seq_len, (
f"Expected at least {seq_len} dummy tokens for profiling, "
f"but got: {len(seq_data.prompt_token_ids)}")
f"but got: {len(decoder_seq_data.prompt_token_ids)}")
assert decoder_dummy_multi_modal_data is None or \
encoder_dummy_multi_modal_data is None, (
"Multi-modal data can't be provided in both encoder and decoder"
)
seq = SequenceGroupMetadata(
request_id=str(group_id),
is_prompt=True,
seq_data={group_id: seq_data},
seq_data={group_id: decoder_seq_data},
sampling_params=sampling_params,
block_tables=None,
encoder_seq_data=seq_data,
encoder_seq_data=encoder_seq_data,
cross_block_table=None,
multi_modal_data=decoder_dummy_multi_modal_data
or encoder_dummy_multi_modal_data,
)
seqs.append(seq)

View File

@ -39,10 +39,6 @@ def assert_enc_dec_mr_supported_scenario(
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PP'])
if enc_dec_mr.model_config.is_multimodal_model:
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_MM'])
if enc_dec_mr.scheduler_config.num_lookahead_slots > 0:
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SPEC_DEC'])