Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 05:04:58 +08:00)
[Model] Add support for the multi-modal Llama 3.2 model (#8811)
Co-authored-by: simon-mo <xmo@berkeley.edu>
Co-authored-by: Chang Su <chang.s.su@oracle.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
parent 4f1ba0844b
commit 770ec6024f
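For orientation before the diff: a minimal offline-inference sketch using the newly supported model. It mirrors the run_mllama example added below; the image path, question text, and sampling settings are illustrative assumptions, and the max_num_seqs / enforce_eager values follow the OOM note in that example.

from PIL import Image

from vllm import LLM, SamplingParams

# Settings follow the example added in this commit; tune max_num_seqs for your GPU.
llm = LLM(
    model="meta-llama/Llama-3.2-11B-Vision-Instruct",
    max_num_seqs=16,
    enforce_eager=True,
)

# Llama 3.2 vision expects the <|image|> placeholder ahead of the text.
prompt = "<|image|><|begin_of_text|>What is shown in this image?"
image = Image.open("example.jpg")  # any local RGB image (illustrative path)

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)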
@@ -254,6 +254,11 @@ Multimodal Language Models
     - Image\ :sup:`+`
     - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
     -
+  * - :code:`MllamaForConditionalGeneration`
+    - Llama 3.2
+    - Image
+    - :code:`meta-llama/Llama-3.2-90B-Vision-Instruct`, :code:`meta-llama/Llama-3.2-11B-Vision`, etc.
+    -
   * - :code:`PaliGemmaForConditionalGeneration`
     - PaliGemma
     - Image\ :sup:`E`
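The table entry above marks the model as image-capable; as a companion, here is a hedged sketch of querying it through vLLM's OpenAI-compatible server. The serve command, port, and image URL are assumptions; the request shape follows the multimodal client example touched later in this diff.

# Assumed server launch (flags mirror the OOM note in the offline example):
#   vllm serve meta-llama/Llama-3.2-11B-Vision-Instruct --max-num-seqs 16 --enforce-eager
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
chat_completion = client.chat.completions.create(
    model="meta-llama/Llama-3.2-11B-Vision-Instruct",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What's in this image?"},
            {"type": "image_url",
             "image_url": {"url": "https://example.com/some_image.jpg"}},  # placeholder URL
        ],
    }],
    max_tokens=64,
)
print(chat_completion.choices[0].message.content)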
@@ -242,6 +242,29 @@ def run_qwen2_vl(question, modality):
     return llm, prompt, stop_token_ids


+# LLama
+def run_mllama(question, modality):
+    assert modality == "image"
+
+    model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"
+
+    # Note: The default setting of max_num_seqs (256) and
+    # max_model_len (131072) for this model may cause OOM.
+    # You may lower either to run this example on lower-end GPUs.
+
+    # The configuration below has been confirmed to launch on a
+    # single H100 GPU.
+    llm = LLM(
+        model=model_name,
+        max_num_seqs=16,
+        enforce_eager=True,
+    )
+
+    prompt = f"<|image|><|begin_of_text|>{question}"
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids
+
+
 model_example_map = {
     "llava": run_llava,
     "llava-next": run_llava_next,
@@ -256,6 +279,7 @@ model_example_map = {
     "internvl_chat": run_internvl,
     "qwen_vl": run_qwen_vl,
     "qwen2_vl": run_qwen2_vl,
+    "mllama": run_mllama,
 }
@@ -38,7 +38,7 @@ chat_completion_from_url = client.chat.completions.create(
         "content": [
             {
                 "type": "text",
-                "text": "What’s in this image?"
+                "text": "What's in this image?"
             },
             {
                 "type": "image_url",
@@ -75,7 +75,7 @@ chat_completion_from_base64 = client.chat.completions.create(
         "content": [
             {
                 "type": "text",
-                "text": "What’s in this image?"
+                "text": "What's in this image?"
             },
             {
                 "type": "image_url",
@@ -4,7 +4,7 @@ numpy < 2.0.0
 requests
 tqdm
 py-cpuinfo
-transformers >= 4.43.2  # Required for Chameleon and Llama 3.1 hotfix.
+transformers >= 4.45.0  # Required for Llama 3.2.
 tokenizers >= 0.19.1  # Required for Llama 3.
 protobuf # Required by LlamaTokenizer.
 fastapi < 0.113.0; python_version < '3.9'
tests/models/encoder_decoder/vision_language/test_mllama.py (new file, 283 lines)
@@ -0,0 +1,283 @@
from typing import List, Optional, Tuple, Type, overload

import pytest
from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
                          BatchEncoding)

from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs

from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
                          _ImageAssets)
from ....utils import multi_gpu_test
from ...utils import check_logprobs_close

_LIMIT_IMAGE_PER_PROMPT = 1

HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign":
    "<|image|><|begin_of_text|>The meaning of the image is",
    "cherry_blossom":
    "<|image|><|begin_of_text|>The city is",
})

text_only_prompts = [
    "The color of the sky is blue but sometimes it can also be",
]

models = [
    "meta-llama/Llama-3.2-11B-Vision-Instruct",
]


def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
                                         Optional[SampleLogprobs]],
                      model: str):
    """Sanitize vllm output to be comparable with hf output."""
    output_ids, output_str, out_logprobs = vllm_output

    config = AutoConfig.from_pretrained(model)
    image_token_id = config.image_token_index

    tokenizer = AutoTokenizer.from_pretrained(model)
    eos_token_id = tokenizer.eos_token_id

    hf_output_ids = [
        token_id for idx, token_id in enumerate(output_ids)
        if token_id != image_token_id or output_ids[idx - 1] != image_token_id
    ]

    assert output_str[0] == " "
    hf_output_str = output_str[1:]
    if hf_output_ids[-1] == eos_token_id:
        hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)

    return hf_output_ids, hf_output_str, out_logprobs


@overload
def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    image_assets: _ImageAssets,
    model: str,
    *,
    size_factors: List[float],
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    ...


@overload
def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    image_assets: _ImageAssets,
    model: str,
    *,
    sizes: List[Tuple[int, int]],
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    ...


def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    image_assets: _ImageAssets,
    model: str,
    *,
    size_factors: Optional[List[float]] = None,
    sizes: Optional[List[Tuple[int, int]]] = None,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    images = [asset.pil_image for asset in image_assets]

    if size_factors is not None:
        inputs_per_image = [(
            [prompt for _ in size_factors],
            [rescale_image_size(image, factor) for factor in size_factors],
        ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
    elif sizes is not None:
        inputs_per_image = [(
            [
                prompt if size is not None else text_only_prompts[0]
                for size in sizes
            ],
            [
                image.resize(size) if size is not None else None
                for size in sizes
            ],
        ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
        if len(sizes) == 0:
            inputs_per_image.append(
                (text_only_prompts, [None] * len(text_only_prompts)))
    else:
        raise ValueError("You must provide either `size_factors` or `sizes`")

    _run_test(hf_runner,
              vllm_runner,
              inputs_per_image,
              model,
              dtype=dtype,
              max_tokens=max_tokens,
              num_logprobs=num_logprobs,
              tensor_parallel_size=tensor_parallel_size,
              distributed_executor_backend=distributed_executor_backend)


def _run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    inputs: List[Tuple[List[str], PromptImageInput]],
    model: str,
    *,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Inference result should be the same between hf and vllm.

    All the image fixtures for the test are from IMAGE_ASSETS.
    For huggingface runner, we provide the PIL images as input.
    For vllm runner, we provide MultiModalDataDict objects
    and corresponding MultiModalConfig as input.
    Note, the text input is also adjusted to abide by vllm contract.
    The text output is sanitized to be able to compare with hf.
    """
    # NOTE: take care of the order. run vLLM first, and then run HF.
    # vLLM needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).

    # max_model_len should be greater than image_feature_size
    with vllm_runner(model,
                     dtype=dtype,
                     max_num_seqs=16,
                     max_model_len=4096,
                     tensor_parallel_size=tensor_parallel_size,
                     distributed_executor_backend=distributed_executor_backend,
                     enforce_eager=True,
                     limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT
                                          }) as vllm_model:
        vllm_outputs_per_image = [
            vllm_model.generate_greedy_logprobs(prompts,
                                                max_tokens,
                                                num_logprobs=num_logprobs,
                                                images=images)
            for prompts, images in inputs
        ]

    def process(hf_inputs: BatchEncoding):
        return hf_inputs

    from transformers import AutoConfig
    from transformers.models.mllama import MllamaConfig as MllamaConfigHf

    # use transformer's MllamaConfig for hf_runner
    # and vllm's MllamaConfig for vllm_runner
    AutoConfig.register("mllama", MllamaConfigHf, exist_ok=True)
    with hf_runner(model,
                   dtype=dtype,
                   postprocess_inputs=process,
                   auto_cls=AutoModelForVision2Seq) as hf_model:
        hf_outputs_per_image = [
            hf_model.generate_greedy_logprobs_limit(prompts,
                                                    max_tokens,
                                                    num_logprobs=num_logprobs,
                                                    images=images)
            for prompts, images in inputs
        ]

    from vllm.transformers_utils.configs.mllama import MllamaConfig
    AutoConfig.register("mllama", MllamaConfig, exist_ok=True)
    for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
                                        vllm_outputs_per_image):
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=[
                vllm_to_hf_output(vllm_output, model)
                for vllm_output in vllm_outputs
            ],
            name_0="hf",
            name_1="vllm",
        )


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "sizes",
    [
        # Text only
        [],
        # Single-size
        [(512, 512)],
        # Single-size, batched
        [(512, 512), (512, 512), (512, 512)],
        # Multi-size, batched
        [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
         (1024, 1024), (512, 1536), (512, 2028)],
        # Multi-size, batched, including text only
        [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
         (1024, 1024), (512, 1536), (512, 2028), None],
        # mllama has 8 possible aspect ratios, carefully set the sizes
        # to cover all of them
    ],
)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, sizes, dtype,
                max_tokens, num_logprobs) -> None:
    run_test(
        hf_runner,
        vllm_runner,
        image_assets,
        model,
        sizes=sizes,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "sizes",
    [
        [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
         (1024, 1024), (512, 1536), (512, 2028), None],
    ],
)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models_distributed(hf_runner, vllm_runner, image_assets, model, sizes,
                            dtype, max_tokens, num_logprobs) -> None:
    run_test(
        hf_runner,
        vllm_runner,
        image_assets,
        model,
        sizes=sizes,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=2,
    )
@@ -576,7 +576,9 @@ class ModelConfig:
     @property
     def is_encoder_decoder_model(self) -> bool:
         """Extract the HF encoder/decoder model flag."""
-        return getattr(self.hf_config, "is_encoder_decoder", False)
+        return getattr(self.hf_config, "is_encoder_decoder", False) or (
+            (hasattr(self.hf_config, "text_config") and getattr(
+                self.hf_config.text_config, "is_encoder_decoder", False)))

     @property
     def is_embedding_model(self) -> bool:
@@ -1734,7 +1734,11 @@ class LLMEngine:

     def _validate_model_inputs(self, inputs: Union[LLMInputs,
                                                     EncoderDecoderLLMInputs]):
-        if self.is_encoder_decoder_model():
+        if self.model_config.is_multimodal_model:
+            # For encoder-decoder multimodal models, the max_prompt_len
+            # restricts the decoder prompt length
+            prompt_ids = inputs.get("prompt_token_ids")
+        elif self.is_encoder_decoder_model():
             prompt_ids = inputs.get("encoder_prompt_token_ids")
         else:
             prompt_ids = inputs.get("prompt_token_ids")
@@ -159,6 +159,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
                                            hf_config.image_token_index)
         if model_type in ("chameleon", "internvl_chat"):
             return "<image>"
+        if model_type == "mllama":
+            return "<|image|>"
         if model_type == "qwen2_vl":
             return "<|vision_start|><|image_pad|><|vision_end|>"

@@ -358,6 +360,7 @@ _TextParser = partial(cast, ChatCompletionContentPartTextParam)
 _ImageParser = partial(cast, ChatCompletionContentPartImageParam)
 _AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
 _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
+MODEL_KEEP_MULTI_MODAL_CONTENT = {'mllama'}


 def _parse_chat_message_content_parts(
@@ -368,7 +371,11 @@ def _parse_chat_message_content_parts(
     texts: List[str] = []

     mm_parser = mm_tracker.create_parser()
+    keep_multimodal_content = \
+        mm_tracker._model_config.hf_config.model_type in \
+        MODEL_KEEP_MULTI_MODAL_CONTENT

+    has_image = False
     for part in parts:
         part_type = part["type"]
         if part_type == "text":
@@ -383,6 +390,7 @@ def _parse_chat_message_content_parts(
                             "will be ignored.")

             mm_parser.parse_image(image_url["url"])
+            has_image = True
         elif part_type == "audio_url":
             audio_url = _AudioParser(part)["audio_url"]

@@ -394,12 +402,20 @@ def _parse_chat_message_content_parts(
             raise NotImplementedError(f"Unknown part type: {part_type}")

     text_prompt = "\n".join(texts)
-    mm_placeholder_counts = mm_parser.mm_placeholder_counts()
-    if mm_placeholder_counts:
-        text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts,
-                                                       text_prompt)
+    if keep_multimodal_content:
+        text_prompt = "\n".join(texts)
+        role_content = [{'type': 'text', 'text': text_prompt}]

-    return [ConversationMessage(role=role, content=text_prompt)]
+        if has_image:
+            role_content = [{'type': 'image'}] + role_content
+        return [ConversationMessage(role=role,
+                                    content=role_content)]  # type: ignore
+    else:
+        mm_placeholder_counts = mm_parser.mm_placeholder_counts()
+        if mm_placeholder_counts:
+            text_prompt = _get_full_multimodal_text_prompt(
+                mm_placeholder_counts, text_prompt)
+        return [ConversationMessage(role=role, content=text_prompt)]


 # No need to validate using Pydantic again
@@ -309,6 +309,8 @@ class OpenAIServingChat(OpenAIServing):
         async for res in result_generator:
             if res.prompt_token_ids is not None:
                 num_prompt_tokens = len(res.prompt_token_ids)
+                if res.encoder_prompt_token_ids is not None:
+                    num_prompt_tokens += len(res.encoder_prompt_token_ids)

             # We need to do it here, because if there are exceptions in
             # the result_generator, it needs to be sent as the FIRST
@@ -139,6 +139,12 @@ class EncoderDecoderLLMInputs(LLMInputs):
     available.
     """

+    encoder_multi_modal_data: NotRequired[Optional["MultiModalDataDict"]]
+    """
+    Optional multi-modal data to pass to the encoder model,
+    if the model supports it.
+    """
+

 _T1 = TypeVar("_T1",
               bound=SingletonPromptInputs,
@@ -128,6 +128,7 @@ class InputPreprocessor:
     def _prepare_decoder_input_ids_for_generation(
         self,
         decoder_input_ids: Optional[List[int]],
+        force_bos: bool = True,
     ) -> List[int]:
         """
         Prepares `decoder_input_ids` for generation with encoder-decoder models.
@@ -157,8 +158,8 @@ class InputPreprocessor:
             # use decoder_start_token_id as decoder_input_ids
             decoder_input_ids = self._get_default_enc_dec_decoder_prompt()

-        if (len(decoder_input_ids) == 0
-                or decoder_input_ids[0] != decoder_start_token_id):
+        if force_bos and (len(decoder_input_ids) == 0
+                          or decoder_input_ids[0] != decoder_start_token_id):
             decoder_input_ids = [decoder_start_token_id] + decoder_input_ids

         return decoder_input_ids
@@ -295,18 +296,25 @@ class InputPreprocessor:
         encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps
         decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps

-        if encoder_mm_data is not None or decoder_mm_data is not None:
-            raise ValueError("Multi-modal encoder-decoder models are "
-                             "not supported yet")
+        if decoder_mm_data is not None:
+            raise ValueError(
+                "Multi-modality decoder inputs of encoder-decoder models are "
+                "not supported yet")

-        decoder_prompt_ids = (
-            self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids))
+        # For Multi-Modal models (e.g., mllama), the text input can be
+        # <|image|><|begin_of_text|>hello world. And we should not add
+        # another <|begin_of_text|> to the beginning.
+        decoder_prompt_ids = (self._prepare_decoder_input_ids_for_generation(
+            decoder_prompt_ids,
+            force_bos=(encoder_mm_data is None and decoder_mm_data is None)))

         return EncoderDecoderLLMInputs(
             prompt_token_ids=decoder_prompt_ids,
             prompt=decoder_prompt,
             multi_modal_data=decoder_mm_data,
             encoder_prompt_token_ids=encoder_prompt_ids,
             encoder_prompt=encoder_prompt,
+            encoder_multi_modal_data=encoder_mm_data,
         )

     def _process_encoder_decoder_prompt(
@@ -112,6 +112,8 @@ class InputRegistry:
     def __init__(self) -> None:
         self._dummy_factories_by_model_type: Dict[Type[nn.Module],
                                                    DummyDataFactory] = {}
+        self._dummy_encoder_factories_by_model_type: Dict[
+            Type[nn.Module], DummyDataFactory] = {}
         self._input_processors_by_model_type: Dict[Type[nn.Module],
                                                     InputProcessor] = {}
@@ -162,11 +164,44 @@ class InputRegistry:
         return self._dummy_factories_by_model_type \
             .get(model_cls, self._default_dummy_data_factory)

+    def register_dummy_encoder_data(self, factory: DummyDataFactory):
+        """
+        Register a dummy encoder data factory to a model class
+
+        This is similar to :meth:`~register_dummy_data`, but for encoder input.
+        """
+
+        def wrapper(model_cls: N) -> N:
+            if model_cls in self._dummy_encoder_factories_by_model_type:
+                logger.warning(
+                    "Model class %s already has dummy encoder data "
+                    "registered to %s. It is overwritten by the new one.",
+                    model_cls, self)
+
+            self._dummy_encoder_factories_by_model_type[model_cls] = factory
+
+            return model_cls
+
+        return wrapper
+
+    def _get_dummy_encoder_data_factory(self, model_cls: Type[nn.Module]):
+        if model_cls in self._dummy_encoder_factories_by_model_type:
+            dummy_factory = self._dummy_encoder_factories_by_model_type[
+                model_cls]
+        else:
+            logger.warning(
+                "No dummy encoder data factory registered to %s. "
+                "Using the dummy data factory for the model instead.",
+                model_cls)
+            dummy_factory = self._get_dummy_data_factory(model_cls)
+        return dummy_factory
+
     def dummy_data_for_profiling(
         self,
         model_config: "ModelConfig",
         seq_len: int,
         mm_registry: "MultiModalRegistry",
+        is_encoder_data: bool = False,
     ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]:
         """
         Create dummy data for profiling the memory usage of a model.
@@ -184,8 +219,10 @@ class InputRegistry:
         from vllm.model_executor.model_loader import get_model_architecture

         model_cls, _ = get_model_architecture(model_config)
-        dummy_factory = self._get_dummy_data_factory(model_cls)
-
+        if is_encoder_data:
+            dummy_factory = self._get_dummy_encoder_data_factory(model_cls)
+        else:
+            dummy_factory = self._get_dummy_data_factory(model_cls)
         mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
         mm_processor_kwargs = get_allowed_kwarg_only_overrides(
             dummy_factory, overrides=model_config.mm_processor_kwargs)
@@ -196,10 +233,15 @@ class InputRegistry:

         # Having more tokens is over-conservative but otherwise fine
         num_tokens = seq_data.prompt_token_ids
-        assert len(num_tokens) >= seq_len, (
-            f"Expected at least {seq_len} dummy tokens for profiling, "
-            f"but found {len(num_tokens)} tokens instead.")
-
+        if len(num_tokens) < seq_len:
+            if is_encoder_data:
+                logger.warning(
+                    "Expected at least %d dummy encoder tokens for profiling, "
+                    "but found %d tokens instead.", seq_len, len(num_tokens))
+            else:
+                raise AssertionError(
+                    f"Expected at least {seq_len} dummy tokens for profiling, "
+                    f"but found {len(num_tokens)} tokens instead.")
         if mm_data is not None:
             for k, v in mm_data.items():
                 num_items = len(v) if isinstance(v, list) else 1
@@ -101,6 +101,8 @@ _MULTIMODAL_MODELS = {
     "Qwen2VLForConditionalGeneration": ("qwen2_vl",
                                         "Qwen2VLForConditionalGeneration"),
     "UltravoxModel": ("ultravox", "UltravoxModel"),
+    "MllamaForConditionalGeneration": ("mllama",
+                                       "MllamaForConditionalGeneration"),
 }
 _CONDITIONAL_GENERATION_MODELS = {
     "BartModel": ("bart", "BartForConditionalGeneration"),
vllm/model_executor/models/mllama.py (new file, 1135 lines)
(File diff suppressed because it is too large.)
@@ -54,6 +54,12 @@ class MultiModalInputs(_MultiModalInputsBase):
         if isinstance(nested_tensors, torch.Tensor):
             return nested_tensors

+        if isinstance(nested_tensors, np.ndarray):
+            return torch.from_numpy(nested_tensors)
+
+        if isinstance(nested_tensors, (int, float)):
+            return torch.tensor(nested_tensors)
+
         stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors]
         if not is_list_of(stacked, torch.Tensor, check="all"):
             # Only tensors (not lists) can be stacked.
@@ -2,6 +2,7 @@ from functools import lru_cache

 import torch
 from PIL import Image
+from transformers.image_processing_base import BatchFeature

 from vllm.config import ModelConfig
 from vllm.inputs.registry import InputContext
@@ -39,6 +40,10 @@ class ImagePlugin(MultiModalPlugin):
     ) -> MultiModalInputs:
         model_config = ctx.model_config

+        # Processed by input processor
+        if isinstance(data, BatchFeature):
+            return MultiModalInputs(data.data)
+
         # PIL image
         if isinstance(data, Image.Image) or is_list_of(data, Image.Image):
             image_processor = self._get_hf_image_processor(model_config)
@@ -13,6 +13,7 @@ from typing import Set, Tuple, Union, cast
 import msgspec
 import torch

+from vllm.inputs import EncoderDecoderLLMInputs, LLMInputs
 from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs
 from vllm.lora.request import LoRARequest
 from vllm.pooling_params import PoolingParams
@@ -21,7 +22,6 @@ from vllm.sampling_params import SamplingParams
 from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics

 if TYPE_CHECKING:
-    from vllm.inputs import LLMInputs
     from vllm.multimodal.base import MultiModalDataDict

 VLLM_TOKEN_ID_ARRAY_TYPE = "l"
@@ -471,7 +471,15 @@ class Sequence:

     @property
     def multi_modal_data(self) -> "MultiModalDataDict":
-        return self.inputs.get("multi_modal_data") or {}
+        if self.inputs.get("multi_modal_data") and self.inputs.get(
+                "encoder_multi_modal_data"):
+            raise ValueError(
+                "Multi-modal data in both encoder and decoder is not supported."
+            )
+        inputs = self.inputs
+        return self.inputs.get("multi_modal_data") or (cast(
+            EncoderDecoderLLMInputs,
+            inputs).get("encoder_multi_modal_data")) or {}

     @property
     def lora_int_id(self) -> int:
@@ -22,9 +22,10 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
                                              EAGLEConfig, ExaoneConfig,
                                              GraniteConfig, InternVLChatConfig,
                                              JAISConfig, MedusaConfig,
-                                             MLPSpeculatorConfig, MPTConfig,
-                                             NemotronConfig, RWConfig,
-                                             SolarConfig, UltravoxConfig)
+                                             MllamaConfig, MLPSpeculatorConfig,
+                                             MPTConfig, NemotronConfig,
+                                             RWConfig, SolarConfig,
+                                             UltravoxConfig)
 # yapf: enable
 from vllm.transformers_utils.utils import check_gguf_file

@@ -37,6 +38,10 @@ MISTRAL_CONFIG_NAME = "params.json"

 logger = init_logger(__name__)

+_CONFIG_REGISTRY_OVERRIDE_HF: Dict[str, Type[PretrainedConfig]] = {
+    "mllama": MllamaConfig
+}
+
 _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     "chatglm": ChatGLMConfig,
     "dbrx": DbrxConfig,
@@ -55,11 +60,15 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     # Granite can be removed from here once we have upgraded to
     # transformers 4.45+
     "granite": GraniteConfig,
+    **_CONFIG_REGISTRY_OVERRIDE_HF
 }

 for name, cls in _CONFIG_REGISTRY.items():
     with contextlib.suppress(ValueError):
-        AutoConfig.register(name, cls)
+        if name in _CONFIG_REGISTRY_OVERRIDE_HF:
+            AutoConfig.register(name, cls, exist_ok=True)
+        else:
+            AutoConfig.register(name, cls)


 class ConfigFormat(str, enum.Enum):
@@ -10,6 +10,7 @@ from vllm.transformers_utils.configs.granite import GraniteConfig
 from vllm.transformers_utils.configs.internvl import InternVLChatConfig
 from vllm.transformers_utils.configs.jais import JAISConfig
 from vllm.transformers_utils.configs.medusa import MedusaConfig
+from vllm.transformers_utils.configs.mllama import MllamaConfig
 from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
 from vllm.transformers_utils.configs.mpt import MPTConfig
 from vllm.transformers_utils.configs.nemotron import NemotronConfig
@@ -26,6 +27,7 @@ __all__ = [
     "MedusaConfig",
     "EAGLEConfig",
     "ExaoneConfig",
+    "MllamaConfig",
     "MLPSpeculatorConfig",
     "NemotronConfig",
     "SolarConfig",
vllm/transformers_utils/configs/mllama.py (new file, 28 lines)
@@ -0,0 +1,28 @@
from transformers.models.mllama import configuration_mllama as mllama_hf_config


class MllamaTextConfig(mllama_hf_config.MllamaTextConfig):
    '''
    Use this class to override is_encoder_decoder:
    - transformers regards mllama as is_encoder_decoder=False
    - vllm needs is_encoder_decoder=True to enable cross-attention
    '''

    def __init__(
        self,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.is_encoder_decoder = True


class MllamaConfig(mllama_hf_config.MllamaConfig):

    def __init__(
        self,
        text_config=None,
        **kwargs,
    ):
        if isinstance(text_config, dict):
            text_config = MllamaTextConfig(**text_config)
        super().__init__(text_config=text_config, **kwargs)
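As an aside, a quick way to see what this override buys (illustrative only: the config is default-constructed here rather than loaded from the released checkpoint):

from vllm.transformers_utils.configs.mllama import MllamaConfig

# vLLM's subclass forces is_encoder_decoder=True on the text config so the
# engine takes its cross-attention (encoder-decoder) code path for mllama.
cfg = MllamaConfig(text_config={})
print(cfg.text_config.is_encoder_decoder)  # True; upstream transformers reports False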
@@ -111,7 +111,6 @@ def get_tokenizer(
             'encoding and decoding.',
             FutureWarning,
             stacklevel=2)
-
     if tokenizer_mode == "mistral":
         tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name),
                                                      revision=revision)
@@ -18,7 +18,8 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
+from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs,
+                             MultiModalRegistry)
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (IntermediateTensors, PoolerOutput,
                            SequenceGroupMetadata)
@@ -52,6 +53,7 @@ class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata):
             "virtual_engine": self.virtual_engine,
             "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
             "finished_requests_ids": self.finished_requests_ids,
+            "multi_modal_kwargs": self.multi_modal_kwargs,
         }
         _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
         _add_sampling_metadata_broadcastable_dict(tensor_dict,
@@ -194,6 +196,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
             "finished_requests_ids": model_input.finished_requests_ids,
             "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
         } if self.has_seqlen_agnostic else {}
+
+        multi_modal_kwargs = model_input.multi_modal_kwargs or {}
         hidden_or_intermediate_states = model_executable(
             input_ids=model_input.input_tokens,
             positions=model_input.input_positions,
@@ -202,6 +206,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
             kv_caches=kv_caches,
             attn_metadata=model_input.attn_metadata,
             intermediate_tensors=intermediate_tensors,
+            **MultiModalInputs.as_kwargs(multi_modal_kwargs,
+                                         device=self.device),
             **seqlen_agnostic_kwargs)

         logits = self.model.compute_logits(hidden_or_intermediate_states,
@@ -288,8 +294,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
         max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
             self.model_config)
         if max_mm_tokens > 0:
-            raise NotImplementedError(
-                "Multi-modal encoder-decoder models are not supported yet")
+            logger.info("Starting profile run for multi-modal models.")

         batch_size = 0
         for group_id in range(max_num_seqs):
@@ -297,24 +302,39 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
                 (group_id < max_num_batched_tokens % max_num_seqs))
             batch_size += seq_len

-            seq_data, _ = self.input_registry \
-                .dummy_data_for_profiling(self.model_config,
+            decoder_seq_data, decoder_dummy_multi_modal_data \
+                = self.input_registry.dummy_data_for_profiling(
+                    self.model_config,
                     seq_len,
-                    self.mm_registry)
+                    self.mm_registry,
+                    is_encoder_data=False)
+            encoder_seq_data, encoder_dummy_multi_modal_data \
+                = self.input_registry.dummy_data_for_profiling(
+                    self.model_config,
+                    seq_len,
+                    self.mm_registry,
+                    is_encoder_data=True)

             # Having more tokens is over-conservative but otherwise fine
-            assert len(seq_data.prompt_token_ids) >= seq_len, (
+            assert len(decoder_seq_data.prompt_token_ids) >= seq_len, (
                 f"Expected at least {seq_len} dummy tokens for profiling, "
-                f"but got: {len(seq_data.prompt_token_ids)}")
+                f"but got: {len(decoder_seq_data.prompt_token_ids)}")
+
+            assert decoder_dummy_multi_modal_data is None or \
+                encoder_dummy_multi_modal_data is None, (
+                "Multi-modal data can't be provided in both encoder and decoder"
+            )

             seq = SequenceGroupMetadata(
                 request_id=str(group_id),
                 is_prompt=True,
-                seq_data={group_id: seq_data},
+                seq_data={group_id: decoder_seq_data},
                 sampling_params=sampling_params,
                 block_tables=None,
-                encoder_seq_data=seq_data,
+                encoder_seq_data=encoder_seq_data,
                 cross_block_table=None,
+                multi_modal_data=decoder_dummy_multi_modal_data
+                or encoder_dummy_multi_modal_data,
             )
             seqs.append(seq)
@@ -39,10 +39,6 @@ def assert_enc_dec_mr_supported_scenario(
         raise NotImplementedError(
             STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PP'])

-    if enc_dec_mr.model_config.is_multimodal_model:
-        raise NotImplementedError(
-            STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_MM'])
-
     if enc_dec_mr.scheduler_config.num_lookahead_slots > 0:
         raise NotImplementedError(
             STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SPEC_DEC'])