[vlm] Remove vision language config. (#6089)

Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
xwjiang2010 2024-07-03 15:14:16 -07:00 committed by GitHub
parent 3c6325f0fc
commit d9e98f42e4
43 changed files with 371 additions and 465 deletions

View File

@ -10,8 +10,13 @@ vLLM provides experimental support for multi-modal models through the :mod:`vllm
:class:`vllm.inputs.PromptStrictInputs` accepts an additional attribute ``multi_modal_data``
which allows you to pass in multi-modal input alongside text and token prompts.
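For example, a minimal sketch of passing an image alongside a text prompt (this simply mirrors the LLaVA usage shown in the VLM guide below; the image path is a placeholder):

.. code-block:: python

    from PIL import Image

    from vllm import LLM

    llm = LLM(model="llava-hf/llava-1.5-7b-hf")

    # Load the image using PIL and attach it via ``multi_modal_data``.
    image = Image.open("example.jpg")

    outputs = llm.generate({
        "prompt": "USER: <image>\nWhat is the content of this image?\nASSISTANT:",
        "multi_modal_data": {"image": image},
    })

    for o in outputs:
        print(o.outputs[0].text)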
.. note::
``multi_modal_data`` can accept keys and values beyond the builtin ones, as long as a customized plugin is registered through
:class:`vllm.multimodal.MULTIMODAL_REGISTRY`.
By default, vLLM models do not support multi-modal inputs. To enable multi-modal support for a model, please follow :ref:`the guide for adding a new multimodal model <adding_a_new_multimodal_model>`.
# TODO: Add more instructions on how to do that once embeddings is in.
Guides

View File

@ -8,18 +8,6 @@ vLLM provides experimental support for Vision Language Models (VLMs). This docum
.. important::
We are actively iterating on VLM support. Expect breaking changes to VLM usage and development in upcoming releases without prior deprecation.
Engine Arguments
----------------
The following :ref:`engine arguments <engine_args>` are specific to VLMs:
.. argparse::
:module: vllm.engine.arg_utils
:func: _vlm_engine_args_parser
:prog: -m vllm.entrypoints.openai.api_server
:nodefaultconst:
.. important::
Currently, the support for vision language models on vLLM has the following limitations:
* Only single image input is supported per text prompt.
@ -33,20 +21,17 @@ To initialize a VLM, the aforementioned arguments must be passed to the ``LLM``
.. code-block:: python
llm = LLM(
model="llava-hf/llava-1.5-7b-hf",
image_token_id=32000,
image_input_shape="1,3,336,336",
image_feature_size=576,
)
llm = LLM(model="llava-hf/llava-1.5-7b-hf")
.. important::
Currently, you have to specify ``image_feature_size`` to support memory profiling.
To avoid OOM during runtime, you should set this to the maximum value supported by the model.
The calculation of feature size is specific to the model. For more details, please refer to
the function :code:`get_<model_name>_image_feature_size` inside the corresponding model file.
We have removed all vision-language-related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
the above snippet. Specifically, ``image_feature_size`` no longer needs to be specified; internally, we construct the data structures needed
to perform profiling for every model.
We will remove most of the vision-specific arguments in a future release as they can be inferred from the HuggingFace configuration.
This work is still ongoing. In the meantime, we internally hardcode ``image_feature_size = 3000`` through
:meth:`MULTIMODAL_REGISTRY.get_num_input_tokens <vllm.multimodal.MultiModalRegistry.get_num_input_tokens>`
for every model to be conservative in terms of GPU memory consumption. This hardcoded value will be replaced
with a more accurate profiling strategy in the future.
To pass an image to the model, note the following in :class:`vllm.inputs.PromptStrictInputs`:
@ -54,19 +39,15 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptS
* ``prompt``: The prompt should follow the format that is documented on HuggingFace.
* ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.
.. note::
``multi_modal_data`` can accept keys and values beyond the builtin ones, as long as a customized plugin is registered through
:class:`vllm.multimodal.MULTIMODAL_REGISTRY`.
.. code-block:: python
# Refer to the HuggingFace repo for the correct format to use
prompt = "USER: <image>\nWhat is the content of this image?\nASSISTANT:"
# Load the image using PIL.Image
image = ...
image = PIL.Image.open(...)
# Single prompt inference
outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": {"image": image},
@ -75,6 +56,26 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptS
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
# Batch inference
image_1 = PIL.Image.open(...)
image_2 = PIL.Image.open(...)
outputs = llm.generate(
[
{
"prompt": "USER: <image>\nWhat is the content of this image?\nASSISTANT:",
"multi_modal_data": {"image": image_1},
},
{
"prompt": "USER: <image>\nWhat's the color of this image?\nASSISTANT:",
"multi_modal_data": {"image": image_2},
}
]
)
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
A code example can be found in `examples/llava_example.py <https://github.com/vllm-project/vllm/blob/main/examples/llava_example.py>`_.
@ -99,18 +100,17 @@ Below is an example on how to launch the same ``llava-hf/llava-1.5-7b-hf`` with
python -m vllm.entrypoints.openai.api_server \
--model llava-hf/llava-1.5-7b-hf \
--image-token-id 32000 \
--image-input-shape 1,3,336,336 \
--image-feature-size 576 \
--chat-template template_llava.jinja
.. important::
Currently, you have to specify ``image_feature_size`` to support memory profiling.
To avoid OOM during runtime, you should set this to the maximum value supported by the model.
The calculation of feature size is specific to the model. For more details, please refer to
the function :code:`get_<model_name>_image_feature_size` inside the corresponding model file.
We have removed all vision-language-related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
the above snippet. Specifically, ``image_feature_size`` no longer needs to be specified; internally, we construct the data structures needed
to perform profiling for every model.
We will remove most of the vision-specific arguments in a future release as they can be inferred from the HuggingFace configuration.
This work is still ongoing. In the meantime, we internally hardcode ``image_feature_size = 3000`` through
:meth:`MULTIMODAL_REGISTRY.get_num_input_tokens <vllm.multimodal.MultiModalRegistry.get_num_input_tokens>`
for every model to be conservative in terms of GPU memory consumption. This hardcoded value will be replaced
with a more accurate profiling strategy in the future.
To consume the server, you can use the OpenAI client like in the example below:
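The client snippet itself falls outside this hunk; as a rough sketch, and assuming the server launched above is listening on the default ``localhost:8000`` port, a call with the official ``openai`` Python client might look like the following (the image URL is purely illustrative):

.. code-block:: python

    from openai import OpenAI

    # Assumption: the vLLM OpenAI-compatible server from the command above
    # is running locally on the default port.
    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    chat_response = client.chat.completions.create(
        model="llava-hf/llava-1.5-7b-hf",
        messages=[{
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image?"},
                {
                    "type": "image_url",
                    # Placeholder URL for illustration only.
                    "image_url": {"url": "https://example.com/some-image.jpg"},
                },
            ],
        }],
    )
    print(chat_response.choices[0].message.content)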

View File

@ -10,12 +10,7 @@ from vllm import LLM
def run_llava():
llm = LLM(
model="llava-hf/llava-1.5-7b-hf",
image_token_id=32000,
image_input_shape="1,3,336,336",
image_feature_size=576,
)
llm = LLM(model="llava-hf/llava-1.5-7b-hf")
prompt = "USER: <image>\nWhat is the content of this image?\nASSISTANT:"

View File

@ -7,13 +7,7 @@ from vllm import LLM, SamplingParams
def run_llava_next():
llm = LLM(
model="llava-hf/llava-v1.6-mistral-7b-hf",
image_token_id=32000,
image_input_shape="1,3,336,336",
# Use the maximum possible value for memory profiling
image_feature_size=2928,
)
llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", max_model_len=4096)
prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
url = "https://h2o-release.s3.amazonaws.com/h2ogpt/bigben.jpg"

View File

@ -3,9 +3,6 @@
Launch the vLLM server with the following command:
python -m vllm.entrypoints.openai.api_server \
--model llava-hf/llava-1.5-7b-hf \
--image-token-id 32000 \
--image-input-shape 1,3,336,336 \
--image-feature-size 576 \
--chat-template template_llava.jinja
"""
import base64

View File

@ -14,15 +14,13 @@ def run_phi3v():
# Note: The default setting of max_num_seqs (256) and
# max_model_len (128k) for this model may cause OOM.
# You may lower either to run this example on lower-end GPUs.
# In this example, we override max_num_seqs to 5 while
# keeping the original context length of 128k.
llm = LLM(
model=model_path,
trust_remote_code=True,
image_token_id=32044,
image_input_shape="1,3,1008,1344",
# Use the maximum possible value for memory profiling
image_feature_size=2653,
max_num_seqs=5,
)

View File

@ -20,9 +20,9 @@ from vllm.utils import cuda_device_count_stateless
model = os.environ["TEST_DIST_MODEL"]
if model.startswith("llava-hf/llava"):
from ..models.test_llava import model_and_vl_config, run_test
from ..models.test_llava import models, run_test
elif model.startswith("microsoft/Phi-3-vision"):
from ..models.test_phi3v import model_and_vl_config, run_test
from ..models.test_phi3v import models, run_test
else:
raise NotImplementedError(f"Unsupported model: {model}")
@ -44,7 +44,7 @@ def test_models(hf_runner, vllm_runner, image_assets,
hf_runner,
vllm_runner,
image_assets,
model_and_config=model_and_vl_config[0],
model=models[0],
size_factors=[1.0],
dtype=dtype,
max_tokens=max_tokens,

View File

@ -39,12 +39,6 @@ def server(ray_ctx):
"--max-model-len",
"4096",
"--enforce-eager",
"--image-token-id",
"32000",
"--image-input-shape",
"1,3,336,336",
"--image-feature-size",
"576",
"--chat-template",
str(LLAVA_CHAT_TEMPLATE),
])

View File

@ -3,7 +3,6 @@ from typing import List, Optional, Tuple, Type
import pytest
from transformers import AutoTokenizer
from vllm.config import VisionLanguageConfig
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
@ -21,49 +20,27 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"USER: <image>\nWhat's in this image?\nASSISTANT:",
})
IMAGE_TOKEN_ID = 32000
def iter_llava_configs(model_name: str):
image_hw_to_feature_size = {
(336, 336): 576,
}
for (h, w), f in image_hw_to_feature_size.items():
input_shape = (1, 3, h, w)
yield (model_name,
VisionLanguageConfig(image_feature_size=f,
image_token_id=32000,
image_input_shape=input_shape))
model_and_vl_config = [
*iter_llava_configs("llava-hf/llava-1.5-7b-hf"),
]
models = ["llava-hf/llava-1.5-7b-hf"]
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
Optional[SampleLogprobs]],
vlm_config: VisionLanguageConfig, model_id: str):
"""Sanitize vllm output to be comparable with hf output.
The function reduces `input_ids` from 1, 32000, 32000, ..., 32000,
x1, x2, x3 ... to 1, 32000, x1, x2, x3 ...
It also reduces `output_str` from "<image><image>bla" to "bla".
"""
model: str):
"""Sanitize vllm output to be comparable with hf output."""
output_ids, output_str, out_logprobs = vllm_output
image_token_id = vlm_config.image_token_id
tokenizer = AutoTokenizer.from_pretrained(model_id)
image_token_str = tokenizer.decode(image_token_id)
tokenizer = AutoTokenizer.from_pretrained(model)
eos_token_id = tokenizer.eos_token_id
hf_output_ids = [
token_id for idx, token_id in enumerate(output_ids)
if token_id != image_token_id or output_ids[idx - 1] != image_token_id
if token_id != IMAGE_TOKEN_ID or output_ids[idx - 1] != IMAGE_TOKEN_ID
]
hf_output_str = output_str \
.replace(image_token_str * vlm_config.image_feature_size, "")
assert hf_output_str[0] == " "
hf_output_str = hf_output_str[1:]
assert output_str[0] == " "
hf_output_str = output_str[1:]
if hf_output_ids[-1] == eos_token_id:
hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)
@ -74,7 +51,7 @@ def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model_and_config: Tuple[str, VisionLanguageConfig],
model: str,
*,
size_factors: List[float],
dtype: str,
@ -92,7 +69,6 @@ def run_test(
Note that the text input is also adjusted to abide by the vllm contract.
The text output is sanitized to be able to compare with hf.
"""
model_id, vlm_config = model_and_config
images = [asset.pil_image for asset in image_assets]
inputs_per_image = [(
@ -106,12 +82,11 @@ def run_test(
# will hurt multiprocessing backend with fork method (the default method).
# max_model_len should be greater than image_feature_size
with vllm_runner(model_id,
with vllm_runner(model,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True,
**vlm_config.as_cli_args_dict()) as vllm_model:
enforce_eager=True) as vllm_model:
vllm_outputs_per_image = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
@ -120,7 +95,7 @@ def run_test(
for prompts, images in inputs_per_image
]
with hf_runner(model_id, dtype=dtype, is_vision_model=True) as hf_model:
with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
@ -136,7 +111,7 @@ def run_test(
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=[
vllm_to_hf_output(vllm_output, vlm_config, model_id)
vllm_to_hf_output(vllm_output, model)
for vllm_output in vllm_outputs
],
name_0="hf",
@ -144,7 +119,7 @@ def run_test(
)
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"size_factors",
[
@ -161,14 +136,13 @@ def run_test(
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
size_factors, dtype: str, max_tokens: int,
num_logprobs: int) -> None:
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
dtype: str, max_tokens: int, num_logprobs: int) -> None:
run_test(
hf_runner,
vllm_runner,
image_assets,
model_and_config,
model,
size_factors=size_factors,
dtype=dtype,
max_tokens=max_tokens,

View File

@ -4,7 +4,6 @@ from typing import List, Optional, Tuple
import pytest
from transformers import AutoTokenizer
from vllm.config import VisionLanguageConfig
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
@ -27,46 +26,22 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
f"{_PREFACE} USER: <image>\nWhat's in this image? ASSISTANT:",
})
def iter_llava_next_configs(model_name: str):
# Need to use the max possible feature size for profile_run
image_hw_to_feature_size = {
(336, 336): 2928,
}
for (h, w), f in image_hw_to_feature_size.items():
input_shape = (1, 3, h, w)
yield (model_name,
VisionLanguageConfig(
image_feature_size=f,
image_token_id=32000,
image_input_shape=input_shape,
))
model_and_vl_config = [
*iter_llava_next_configs("llava-hf/llava-v1.6-vicuna-7b-hf"),
]
IMAGE_TOKEN_ID = 32000
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
Optional[SampleLogprobs]],
vlm_config: VisionLanguageConfig, model_id: str):
"""Sanitize vllm output to be comparable with hf output.
The function reduces `input_ids` from 1, 32000, 32000, ..., 32000,
x1, x2, x3 ... to 1, 32000, x1, x2, x3 ...
It also reduces `output_str` from "<image><image>bla" to "bla".
"""
model: str):
"""Sanitize vllm output to be comparable with hf output."""
output_ids, output_str, out_logprobs = vllm_output
image_token_id = vlm_config.image_token_id
tokenizer = AutoTokenizer.from_pretrained(model_id)
image_token_str = tokenizer.decode(image_token_id)
tokenizer = AutoTokenizer.from_pretrained(model)
image_token_str = tokenizer.decode(IMAGE_TOKEN_ID)
eos_token_id = tokenizer.eos_token_id
hf_output_ids = [
token_id for idx, token_id in enumerate(output_ids)
if token_id != image_token_id or output_ids[idx - 1] != image_token_id
if token_id != IMAGE_TOKEN_ID or output_ids[idx - 1] != IMAGE_TOKEN_ID
]
hf_output_str = re.sub(fr"({image_token_str})+", "", output_str)
@ -78,7 +53,7 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
return hf_output_ids, hf_output_str, out_logprobs
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-vicuna-7b-hf"])
@pytest.mark.parametrize(
"size_factors",
[
@ -95,9 +70,8 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
size_factors, dtype: str, max_tokens: int,
num_logprobs: int) -> None:
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
dtype, max_tokens, num_logprobs) -> None:
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test are under tests/images.
@ -107,7 +81,6 @@ def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
Note that the text input is also adjusted to abide by the vllm contract.
The text output is sanitized to be able to compare with hf.
"""
model_id, vlm_config = model_and_config
images = [asset.pil_image for asset in image_assets]
inputs_per_image = [(
@ -116,11 +89,10 @@ def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
# max_model_len should be greater than image_feature_size
with vllm_runner(model_id,
with vllm_runner(model,
dtype=dtype,
max_model_len=4096,
enforce_eager=True,
**vlm_config.as_cli_args_dict()) as vllm_model:
enforce_eager=True) as vllm_model:
vllm_outputs_per_image = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
@ -129,7 +101,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
for prompts, images in inputs_per_image
]
with hf_runner(model_id, dtype=dtype, is_vision_model=True) as hf_model:
with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
@ -145,7 +117,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=[
vllm_to_hf_output(vllm_output, vlm_config, model_id)
vllm_to_hf_output(vllm_output, model)
for vllm_output in vllm_outputs
],
name_0="hf",

View File

@ -4,7 +4,6 @@ from typing import List, Optional, Tuple, Type
import pytest
from transformers import AutoTokenizer
from vllm.config import VisionLanguageConfig
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
from vllm.utils import is_cpu
@ -23,35 +22,14 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"<|user|>\n<|image_1|>\nWhat's in this image?<|end|>\n<|assistant|>\n",
})
def iter_phi3v_configs(model_name: str):
# Need to use the max possible feature size for profile_run
image_hw_to_feature_size = {
(1008, 1344): 2653,
}
for (h, w), f in image_hw_to_feature_size.items():
input_shape = (1, 3, h, w)
yield (model_name,
VisionLanguageConfig(image_feature_size=f,
image_token_id=32044,
image_input_shape=input_shape))
model_and_vl_config = [
*iter_phi3v_configs("microsoft/Phi-3-vision-128k-instruct"),
]
models = ["microsoft/Phi-3-vision-128k-instruct"]
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
Optional[SampleLogprobs]],
vlm_config: VisionLanguageConfig, model_id: str):
"""Sanitize vllm output to be comparable with hf output.
The function reduces `input_ids` from 1, 32000, 32000, ..., 32000,
x1, x2, x3 ... to 1, 32000, x1, x2, x3 ...
It also reduces `output_str` from "<image><image>bla" to "bla".
"""
output_ids, output_str, out_logprobs = vllm_output
model: str):
"""Sanitize vllm output to be comparable with hf output."""
_, output_str, out_logprobs = vllm_output
output_str_without_image = re.sub(r"(<\|image_\d+\|>)+", "", output_str)
assert output_str_without_image[0] == " "
@ -60,7 +38,7 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
hf_output_str = output_str_without_image.replace("<|user|>", "") \
.replace("<|end|>\n<|assistant|>", " ")
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model)
hf_output_ids = tokenizer.encode(output_str_without_image)
assert hf_output_ids[0] == 1
hf_output_ids = hf_output_ids[1:]
@ -77,7 +55,7 @@ def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model_and_config: Tuple[str, VisionLanguageConfig],
model: str,
*,
size_factors: List[float],
dtype: str,
@ -95,7 +73,6 @@ def run_test(
Note that the text input is also adjusted to abide by the vllm contract.
The text output is sanitized to be able to compare with hf.
"""
model_id, vlm_config = model_and_config
images = [asset.pil_image for asset in image_assets]
inputs_per_image = [(
@ -109,13 +86,13 @@ def run_test(
# will hurt multiprocessing backend with fork method (the default method).
# max_model_len should be greater than image_feature_size
with vllm_runner(model_id,
with vllm_runner(model,
max_model_len=4096,
max_num_seqs=1,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True,
**vlm_config.as_cli_args_dict()) as vllm_model:
enforce_eager=True) as vllm_model:
vllm_outputs_per_image = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
@ -126,7 +103,7 @@ def run_test(
# use eager mode for hf runner, since phi3_v didn't work with flash_attn
hf_model_kwargs = {"_attn_implementation": "eager"}
with hf_runner(model_id, dtype=dtype,
with hf_runner(model, dtype=dtype,
model_kwargs=hf_model_kwargs) as hf_model:
eos_token_id = hf_model.processor.tokenizer.eos_token_id
hf_outputs_per_image = [
@ -143,7 +120,7 @@ def run_test(
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=[
vllm_to_hf_output(vllm_output, vlm_config, model_id)
vllm_to_hf_output(vllm_output, model)
for vllm_output in vllm_outputs
],
name_0="hf",
@ -153,7 +130,7 @@ def run_test(
# Since we use _attn_implementation="eager" for hf_runner, there is more
# significant numerical difference. The basic `logprobs=5` fails to pass.
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"size_factors",
[
@ -170,14 +147,13 @@ def run_test(
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
size_factors, dtype: str, max_tokens: int,
num_logprobs: int) -> None:
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
dtype: str, max_tokens: int, num_logprobs: int) -> None:
run_test(
hf_runner,
vllm_runner,
image_assets,
model_and_config,
model,
size_factors=size_factors,
dtype=dtype,
max_tokens=max_tokens,

View File

@ -1,8 +1,7 @@
import enum
import json
from dataclasses import dataclass, field, fields
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple,
Union)
from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Union
import torch
from transformers import PretrainedConfig
@ -120,7 +119,7 @@ class ModelConfig:
disable_sliding_window: bool = False,
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None,
multimodal_config: Optional["VisionLanguageConfig"] = None,
multimodal_config: Optional["MultiModalConfig"] = None,
) -> None:
self.model = model
self.tokenizer = tokenizer
@ -1289,35 +1288,12 @@ class LoRAConfig:
raise ValueError("LoRA is not supported with chunked prefill yet.")
# TODO: To be replaced by MultiModalConfig.
@dataclass
class VisionLanguageConfig:
class MultiModalConfig:
"""Configs the input data format and how models should run for
vision language models."""
# The input id corresponding to image token.
image_token_id: int
# Used for running `run_prefill_max_token`.
# For models that support varying resolution, this corresponds to
# worst case scenario (biggest supported resolution).
image_input_shape: tuple
image_feature_size: int
def as_cli_args_dict(self) -> Dict[str, Any]:
"""Flatten vision language config to pure args.
Compatible with what llm entrypoint expects.
"""
result: Dict[str, Any] = {}
for f in fields(self):
value = getattr(self, f.name)
if isinstance(value, enum.Enum):
result[f.name] = value.name.lower()
elif isinstance(value, tuple):
result[f.name] = ",".join([str(item) for item in value])
else:
result[f.name] = value
return result
multimodal models."""
# TODO: Add configs to init vision tower or not.
pass
_STR_DTYPE_TO_TORCH_DTYPE = {
@ -1541,7 +1517,7 @@ class EngineConfig:
device_config: DeviceConfig
load_config: LoadConfig
lora_config: Optional[LoRAConfig]
vision_language_config: Optional[VisionLanguageConfig]
multimodal_config: Optional[MultiModalConfig]
speculative_config: Optional[SpeculativeConfig]
decoding_config: Optional[DecodingConfig]
observability_config: Optional[ObservabilityConfig]

View File

@ -6,11 +6,11 @@ from typing import List, Optional, Tuple, Union
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
ObservabilityConfig, ParallelConfig, SchedulerConfig,
SpeculativeConfig, TokenizerPoolConfig,
VisionLanguageConfig)
MultiModalConfig, ObservabilityConfig, ParallelConfig,
SchedulerConfig, SpeculativeConfig,
TokenizerPoolConfig)
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser, str_to_int_tuple
from vllm.utils import FlexibleArgumentParser
def nullable_str(val: str):
@ -78,11 +78,6 @@ class EngineArgs:
model_loader_extra_config: Optional[dict] = None
preemption_mode: Optional[str] = None
# Related to Vision-language models such as llava
image_token_id: Optional[int] = None
image_input_shape: Optional[str] = None
image_feature_size: Optional[int] = None
scheduler_delay_factor: float = 0.0
enable_chunked_prefill: bool = False
@ -106,27 +101,6 @@ class EngineArgs:
if self.tokenizer is None:
self.tokenizer = self.model
@staticmethod
def add_cli_args_for_vlm(
parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument('--image-token-id',
type=int,
default=None,
help=('Input id for image token.'))
parser.add_argument(
'--image-input-shape',
type=nullable_str,
default=None,
help=('The biggest image input shape (worst for memory footprint) '
'given an input type. Only used for vLLM\'s profile_run.'))
parser.add_argument(
'--image-feature-size',
type=int,
default=None,
help=('The image feature size along the context dimension.'))
return parser
@staticmethod
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"""Shared CLI arguments for vLLM engine."""
@ -484,9 +458,6 @@ class EngineArgs:
],
help='Device type for vLLM execution.')
# Related to Vision-language models such as llava
parser = EngineArgs.add_cli_args_for_vlm(parser)
parser.add_argument(
'--scheduler-delay-factor',
type=float,
@ -648,19 +619,7 @@ class EngineArgs:
raise ValueError(
"BitsAndBytes load format and QLoRA adapter only support "
f"'bitsandbytes' quantization, but got {self.quantization}")
if self.image_token_id is not None:
if (not self.image_input_shape or not self.image_feature_size):
raise ValueError(
'Specify `image_input_shape` and '
'`image_feature_size` together with `image_token_id`.')
vision_language_config = VisionLanguageConfig(
image_token_id=self.image_token_id,
image_input_shape=str_to_int_tuple(self.image_input_shape),
image_feature_size=self.image_feature_size,
)
else:
vision_language_config = None
multimodal_config = MultiModalConfig()
device_config = DeviceConfig(device=self.device)
model_config = ModelConfig(
@ -685,7 +644,7 @@ class EngineArgs:
disable_sliding_window=self.disable_sliding_window,
skip_tokenizer_init=self.skip_tokenizer_init,
served_model_name=self.served_model_name,
multimodal_config=vision_language_config)
multimodal_config=multimodal_config)
cache_config = CacheConfig(
block_size=self.block_size,
gpu_memory_utilization=self.gpu_memory_utilization,
@ -787,7 +746,7 @@ class EngineArgs:
scheduler_config=scheduler_config,
device_config=device_config,
lora_config=lora_config,
vision_language_config=vision_language_config,
multimodal_config=multimodal_config,
speculative_config=speculative_config,
load_config=load_config,
decoding_config=decoding_config,
@ -831,7 +790,3 @@ def _engine_args_parser():
def _async_engine_args_parser():
return AsyncEngineArgs.add_cli_args(FlexibleArgumentParser(),
async_args_only=True)
def _vlm_engine_args_parser():
return EngineArgs.add_cli_args_for_vlm(FlexibleArgumentParser())

View File

@ -7,9 +7,9 @@ from typing import Set, Type, TypeVar, Union
from transformers import PreTrainedTokenizer
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
LoRAConfig, ModelConfig, ObservabilityConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
VisionLanguageConfig)
LoRAConfig, ModelConfig, MultiModalConfig,
ObservabilityConfig, ParallelConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
SchedulerOutputs)
from vllm.engine.arg_utils import EngineArgs
@ -87,8 +87,8 @@ class LLMEngine:
scheduler_config: The configuration related to the request scheduler.
device_config: The configuration related to the device.
lora_config (Optional): The configuration related to serving multi-LoRA.
vision_language_config (Optional): The configuration related to vision
language models.
multimodal_config (Optional): The configuration related to multimodal
models.
speculative_config (Optional): The configuration related to speculative
decoding.
executor_class: The model executor class for managing distributed
@ -157,7 +157,7 @@ class LLMEngine:
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
multimodal_config: Optional[MultiModalConfig],
speculative_config: Optional[SpeculativeConfig],
decoding_config: Optional[DecodingConfig],
observability_config: Optional[ObservabilityConfig],
@ -215,7 +215,7 @@ class LLMEngine:
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.vision_language_config = vision_language_config
self.multimodal_config = multimodal_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
@ -247,7 +247,7 @@ class LLMEngine:
scheduler_config=scheduler_config,
device_config=device_config,
lora_config=lora_config,
vision_language_config=vision_language_config,
multimodal_config=multimodal_config,
speculative_config=speculative_config,
load_config=load_config,
)

View File

@ -121,6 +121,11 @@ class LLM:
) -> None:
if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True
removed_vision_keys = ("image_token_id", "image_feature_size",
"image_input_shape", "image_input_type")
if any(k in kwargs for k in removed_vision_keys):
raise TypeError(
"There is no need to pass vision-related arguments anymore.")
engine_args = EngineArgs(
model=model,
tokenizer=tokenizer,

View File

@ -109,23 +109,12 @@ class OpenAIServingChat(OpenAIServing):
"paligemma"):
# These models do not use image tokens in the prompt
return None
if model_type.startswith("llava"):
return self.tokenizer.decode(
self.model_config.hf_config.image_token_index)
# The default behaviour assumes that the image token is
# available to the tokenizer.
# (Suitable for LLaVA, Idefics2, DeepSeek-VL)
vlm_config = self.model_config.multimodal_config
if vlm_config is None:
raise ValueError(
"'image_url' input is not supported as the loaded "
"model is not multimodal.")
image_token_id = vlm_config.image_token_id
if vlm_config.image_token_id is None:
raise ValueError(
"'image_url' input is not supported as the loaded "
"model does not specify an image token.")
return self.tokenizer.decode(image_token_id)
else:
raise TypeError(f"Unknown model type: {model_type}")
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)

View File

@ -46,7 +46,7 @@ class CPUExecutor(ExecutorBase):
rank=0,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
multimodal_config=self.multimodal_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=True,
)

View File

@ -3,8 +3,8 @@ from abc import ABC, abstractmethod
from typing import List, Optional, Set, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
SpeculativeConfig, VisionLanguageConfig)
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig, SpeculativeConfig)
from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput
@ -26,7 +26,7 @@ class ExecutorBase(ABC):
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
multimodal_config: Optional[MultiModalConfig],
speculative_config: Optional[SpeculativeConfig],
) -> None:
self.model_config = model_config
@ -36,7 +36,7 @@ class ExecutorBase(ABC):
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.vision_language_config = vision_language_config
self.multimodal_config = multimodal_config
self.speculative_config = speculative_config
self._init_executor()
@ -120,7 +120,7 @@ class ExecutorAsyncBase(ExecutorBase):
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
multimodal_config: Optional[MultiModalConfig],
speculative_config: Optional[SpeculativeConfig],
) -> None:
# This locks each pipeline parallel stage so multiple virtual engines
@ -132,8 +132,7 @@ class ExecutorAsyncBase(ExecutorBase):
super().__init__(model_config, cache_config, parallel_config,
scheduler_config, device_config, load_config,
lora_config, vision_language_config,
speculative_config)
lora_config, multimodal_config, speculative_config)
@abstractmethod
async def execute_model_async(

View File

@ -43,7 +43,7 @@ class GPUExecutor(ExecutorBase):
rank=rank,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
multimodal_config=self.multimodal_config,
speculative_config=self.speculative_config,
is_driver_worker=(not self.parallel_config)
or (rank % self.parallel_config.tensor_parallel_size == 0),

View File

@ -47,7 +47,7 @@ class OpenVINOExecutor(ExecutorBase):
rank=0,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
multimodal_config=self.multimodal_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=True,
)

View File

@ -7,8 +7,8 @@ from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Set,
Tuple, Union)
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
SpeculativeConfig, VisionLanguageConfig)
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig, SpeculativeConfig)
from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.ray_utils import RayWorkerWrapper, ray
@ -43,7 +43,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
multimodal_config: Optional[MultiModalConfig],
speculative_config: Optional[SpeculativeConfig],
) -> None:
assert device_config.device_type == "xpu"
@ -57,7 +57,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.vision_language_config = vision_language_config
self.multimodal_config = multimodal_config
placement_group = self.parallel_config.placement_group
@ -199,7 +199,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
rank=rank,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
multimodal_config=self.multimodal_config,
is_driver_worker=rank == 0,
))
self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)

View File

@ -50,7 +50,7 @@ class TPUExecutor(ExecutorBase):
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
vision_language_config=self.vision_language_config,
multimodal_config=self.multimodal_config,
is_driver_worker=rank == 0,
)

View File

@ -3,8 +3,8 @@ from typing import List, Optional
import torch
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
SpeculativeConfig, VisionLanguageConfig)
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig, SpeculativeConfig)
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger
@ -26,7 +26,7 @@ class XPUExecutor(GPUExecutor):
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
multimodal_config: Optional[MultiModalConfig],
speculative_config: Optional[SpeculativeConfig],
) -> None:
assert device_config.device_type == "xpu"
@ -42,7 +42,7 @@ class XPUExecutor(GPUExecutor):
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.vision_language_config = vision_language_config
self.multimodal_config = multimodal_config
self.speculative_config = None
# Instantiate the worker and load the model to GPU.

View File

@ -11,7 +11,7 @@ from vllm.logger import init_logger
from .data import LLMInputs
if TYPE_CHECKING:
from vllm.config import ModelConfig, VisionLanguageConfig
from vllm.config import ModelConfig, MultiModalConfig
from vllm.multimodal import MultiModalDataDict
from vllm.sequence import SequenceData
@ -30,7 +30,7 @@ class InputContext:
model_config: "ModelConfig"
"""The configuration of the model."""
def get_multimodal_config(self) -> "VisionLanguageConfig":
def get_multimodal_config(self) -> "MultiModalConfig":
"""
Get the multimodal configuration of the model.

View File

@ -3,8 +3,8 @@ from typing import Optional
from torch import nn
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig)
from vllm.model_executor.model_loader.loader import (BaseModelLoader,
get_model_loader)
from vllm.model_executor.model_loader.utils import (
@ -15,13 +15,13 @@ def get_model(*, model_config: ModelConfig, load_config: LoadConfig,
device_config: DeviceConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
multimodal_config: Optional[MultiModalConfig],
cache_config: CacheConfig) -> nn.Module:
loader = get_model_loader(load_config)
return loader.load_model(model_config=model_config,
device_config=device_config,
lora_config=lora_config,
vision_language_config=vision_language_config,
multimodal_config=multimodal_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
cache_config=cache_config)

View File

@ -16,8 +16,8 @@ from huggingface_hub import HfApi, hf_hub_download
from torch import nn
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
LoRAConfig, ModelConfig, ParallelConfig,
SchedulerConfig, VisionLanguageConfig)
LoRAConfig, ModelConfig, MultiModalConfig,
ParallelConfig, SchedulerConfig)
from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
@ -68,7 +68,7 @@ def _get_quantization_config(
def _get_model_initialization_kwargs(
model_class: Type[nn.Module],
lora_config: Optional[LoRAConfig],
vlm_config: Optional[VisionLanguageConfig],
multimodal_config: Optional[MultiModalConfig],
) -> Dict[str, Any]:
"""Get extra kwargs for model initialization."""
extra_kwargs: Dict[str, Any] = {}
@ -84,18 +84,18 @@ def _get_model_initialization_kwargs(
"please open an issue on github.")
if supports_vision(model_class):
if vlm_config is None:
if multimodal_config is None:
raise ValueError("Provide vision related configurations "
"through LLM entrypoint or engine arguments.")
extra_kwargs["vlm_config"] = vlm_config
extra_kwargs["multimodal_config"] = multimodal_config
return extra_kwargs
def _initialize_model(model_config: ModelConfig, load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
multimodal_config: Optional[MultiModalConfig],
cache_config: CacheConfig) -> nn.Module:
"""Initialize a model with the given configurations."""
model_class = get_model_architecture(model_config)[0]
@ -105,7 +105,7 @@ def _initialize_model(model_config: ModelConfig, load_config: LoadConfig,
cache_config=cache_config,
quant_config=quant_config,
**_get_model_initialization_kwargs(
model_class, lora_config, vision_language_config))
model_class, lora_config, multimodal_config))
class BaseModelLoader(ABC):
@ -118,7 +118,7 @@ class BaseModelLoader(ABC):
def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
multimodal_config: Optional[MultiModalConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
@ -258,14 +258,14 @@ class DefaultModelLoader(BaseModelLoader):
def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
multimodal_config: Optional[MultiModalConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, vision_language_config,
lora_config, multimodal_config,
cache_config)
model.load_weights(
self._get_weights_iterator(model_config.model,
@ -298,14 +298,14 @@ class DummyModelLoader(BaseModelLoader):
def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
multimodal_config: Optional[MultiModalConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, vision_language_config,
lora_config, multimodal_config,
cache_config)
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
@ -339,7 +339,7 @@ class TensorizerLoader(BaseModelLoader):
model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
multimodal_config: Optional[MultiModalConfig],
cache_config: CacheConfig,
) -> nn.Module:
"""Load a serialized model with tensorizer to the CPU.
@ -352,7 +352,7 @@ class TensorizerLoader(BaseModelLoader):
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, vision_language_config,
lora_config, multimodal_config,
cache_config)
model.load_weights(self._get_weights_iterator())
@ -363,7 +363,7 @@ class TensorizerLoader(BaseModelLoader):
model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
multimodal_config: Optional[MultiModalConfig],
cache_config: CacheConfig,
) -> nn.Module:
"""Load a serialized model with tensorizer.
@ -377,7 +377,7 @@ class TensorizerLoader(BaseModelLoader):
quant_config = _get_quantization_config(
model_config, self.load_config)
extra_kwargs = _get_model_initialization_kwargs(
model_class, lora_config, vision_language_config)
model_class, lora_config, multimodal_config)
extra_kwargs["quant_config"] = quant_config
extra_kwargs["cache_config"] = cache_config
@ -392,7 +392,7 @@ class TensorizerLoader(BaseModelLoader):
def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
multimodal_config: Optional[MultiModalConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
@ -406,12 +406,10 @@ class TensorizerLoader(BaseModelLoader):
if is_vllm_tensorized(self.tensorizer_config):
return self._load_model_serialized(model_config, device_config,
lora_config,
vision_language_config,
lora_config, multimodal_config,
cache_config)
return self._load_model_serialized_cpu(model_config, device_config,
lora_config,
vision_language_config,
lora_config, multimodal_config,
cache_config)
@staticmethod
@ -494,7 +492,7 @@ class ShardedStateLoader(BaseModelLoader):
def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
multimodal_config: Optional[MultiModalConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
@ -508,7 +506,7 @@ class ShardedStateLoader(BaseModelLoader):
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, vision_language_config,
lora_config, multimodal_config,
cache_config)
rank = get_tensor_model_parallel_rank()
pattern = os.path.join(
@ -804,14 +802,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
multimodal_config: Optional[MultiModalConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, vision_language_config,
lora_config, multimodal_config,
cache_config)
self._load_weights(model_config, model)

View File

@ -3,7 +3,7 @@ from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type,
from typing_extensions import TypeGuard
from vllm.config import LoRAConfig, VisionLanguageConfig
from vllm.config import LoRAConfig, MultiModalConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
@ -22,7 +22,7 @@ class SupportsVision(Protocol):
MRO of your model class.
"""
def __init__(self, *, vlm_config: VisionLanguageConfig) -> None:
def __init__(self, *, multimodal_config: MultiModalConfig) -> None:
...
@ -32,7 +32,7 @@ class SupportsVision(Protocol):
class _SupportsVisionType(Protocol):
supports_vision: Literal[True]
def __call__(self, *, vlm_config: VisionLanguageConfig) -> None:
def __call__(self, *, multimodal_config: MultiModalConfig) -> None:
...

View File

@ -5,7 +5,7 @@ import torch.nn as nn
from transformers import CLIPVisionConfig, LlavaConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VisionLanguageConfig
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@ -108,13 +108,13 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
def __init__(self,
config: LlavaConfig,
vlm_config: VisionLanguageConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.config = config
self.vlm_config = vlm_config
self.multimodal_config = multimodal_config
# TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = CLIPVisionModel(config.vision_config)
@ -138,14 +138,13 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
self.sampler = Sampler()
def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor:
if list(data.shape[1:]) != list(self.vlm_config.image_input_shape[1:]):
if list(data.shape)[1:] != [
3, self.config.vision_config.image_size,
self.config.vision_config.image_size
]:
raise ValueError(
f"The expected image tensor shape is batch dimension plus "
f"{self.vlm_config.image_input_shape[1:]}. "
f"You supplied {data.shape}. "
f"If you are using vLLM's entrypoint, make sure your "
f"supplied image input is consistent with "
f"image_input_shape in engine args.")
"The expected image tensor shape is batch dimension plus "
"channel, height and width.")
return data
@ -244,7 +243,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
inputs_embeds = merge_vision_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.vlm_config.image_token_id)
self.config.image_token_index)
input_ids = None
else:

View File

@ -1,4 +1,4 @@
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
import torch
import torch.nn as nn
@ -9,7 +9,7 @@ from transformers.models.llava_next.modeling_llava_next import (
from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VisionLanguageConfig
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@ -204,13 +204,13 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
def __init__(self,
config: LlavaNextConfig,
vlm_config: VisionLanguageConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.config = config
self.vlm_config = vlm_config
self.multimodal_config = multimodal_config
# TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = CLIPVisionModel(config=config.vision_config)
@ -244,6 +244,47 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
return data
def _validate_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]:
def _validate_shape(data: torch.Tensor):
dim = data.dim()
height = width = self.config.vision_config.image_size
# All 4d image tensors have the same number of patches,
# so data is a 5d batch of these tensors
if dim == 5:
if list(data.shape)[2:] != [
3, self.config.vision_config.image_size,
self.config.vision_config.image_size
]:
raise ValueError(
"Expected pixel value tensor in shape of: (batch size, "
f"patch number, 3, {height}, {width}), got {data.shape}"
)
# 4d image tensors have different numbers of patches,
# so data is each individual tensor.
elif dim == 4:
if list(data.shape)[1:] != [
3, self.config.vision_config.image_size,
self.config.vision_config.image_size
]:
raise ValueError(
"Expected pixel value tensor in shape of: (patch "
f"number, 3, {height}, {width}), got {data.shape}")
else:
raise ValueError(
f"Invalid pixel value tensor of shape {data.shape}")
if isinstance(data, torch.Tensor):
_validate_shape(data)
else:
[_validate_shape(d) for d in data]
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaNextImagePixelInputs]:
pixel_values = kwargs.pop("pixel_values", None)
@ -262,7 +303,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
return LlavaNextImagePixelInputs(
type="pixel_values",
data=pixel_values,
data=self._validate_pixel_values(pixel_values),
image_sizes=self._validate_image_sizes(image_sizes),
)
@ -454,7 +495,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
inputs_embeds = merge_vision_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.vlm_config.image_token_id)
self.config.image_token_index)
input_ids = None
else:

View File

@ -15,7 +15,7 @@
# limitations under the License.
import re
from functools import lru_cache
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
import numpy as np
import torch
@ -24,7 +24,7 @@ from PIL import Image
from transformers import CLIPVisionConfig, PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VisionLanguageConfig
from vllm.config import CacheConfig, ModelConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@ -50,6 +50,9 @@ _KEYS_TO_MODIFY_MAPPING = {
"model.vision_embed_tokens": "vision_embed_tokens",
}
# Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 32044
CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
hidden_act="quick_gelu",
hidden_size=1024,
@ -95,13 +98,10 @@ class Phi3ImageEmbeddingBase(nn.Module):
class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
"""Phi3 Image embedding with HD transform."""
def __init__(self,
vision_language_config: VisionLanguageConfig,
config: PretrainedConfig,
wte=None) -> None:
def __init__(self, config: PretrainedConfig, wte=None) -> None:
super().__init__(wte)
self.image_token_id = vision_language_config.image_token_id
self.image_token_id = _IMAGE_TOKEN_ID
# n_embed or hidden_size
hidden_size = config.n_embd if hasattr(
config, 'n_embd') else config.hidden_size
@ -333,7 +333,7 @@ def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
seq_data = dummy_seq_data_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
seq_len,
image_token_id=32044,
image_token_id=_IMAGE_TOKEN_ID,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_image_for_clip(
@ -370,7 +370,6 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
return llm_inputs
model_config = ctx.model_config
multimodal_config = ctx.get_multimodal_config()
hf_config = ctx.get_hf_config(PretrainedConfig)
image_data = multi_modal_data["image"]
@ -407,7 +406,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
new_token_ids: List[int] = []
for i in range(len(prompt_token_ids) - len(image_1_token_ids) + 1):
if prompt_token_ids[i:i + len(image_1_token_ids)] == image_1_token_ids:
new_token_ids.append(multimodal_config.image_token_id)
new_token_ids.append(_IMAGE_TOKEN_ID)
# No need to further scan the list since we only replace once
new_token_ids.extend(prompt_token_ids[i + len(image_1_token_ids):])
@ -424,7 +423,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
model_config,
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
llm_inputs,
image_token_id=multimodal_config.image_token_id,
image_token_id=_IMAGE_TOKEN_ID,
image_feature_size_override=image_feature_size,
)
@ -436,25 +435,53 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
def __init__(self,
config: PretrainedConfig,
vlm_config: VisionLanguageConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.config = config
self.vlm_config = vlm_config
self.multimodal_config = multimodal_config
self.model = LlamaModel(config, cache_config, quant_config)
# TODO: Optionally initializes this for supporting embeddings.
self.vision_embed_tokens = Phi3HDImageEmbedding(
vlm_config, config, self.model.embed_tokens)
config, self.model.embed_tokens)
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
if list(data.shape[1:]) != [2]:
raise ValueError(
f"The expected image sizes shape is batch dimension plus "
f"{[2]}. You supplied {data.shape}.")
return data
def _validate_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]:
def _validate_shape(data: torch.Tensor):
if list(data.shape)[2:] != [
3, CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size,
CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
]:
raise ValueError(
"The expected pixel value tensor shape is batch dimension "
"plus patch number, channel, height and width.")
if isinstance(data, torch.Tensor):
_validate_shape(data)
else:
[_validate_shape(d) for d in data]
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Phi3VImagePixelInputs]:
pixel_values = kwargs.pop("pixel_values", None)
@ -471,9 +498,10 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
raise ValueError("Incorrect type of image sizes. "
f"Got type: {type(image_sizes)}")
return Phi3VImagePixelInputs(type="pixel_values",
data=pixel_values,
image_sizes=image_sizes)
return Phi3VImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
image_sizes=self._validate_image_sizes(image_sizes))
def forward(self,
input_ids: torch.Tensor,

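For reference, a minimal sketch (not part of the diff) of the tensor shapes the new Phi3-V validators accept. It assumes the hardcoded CLIP ViT-L/14-336 configuration (image_size = 336); the batch size, patch count, and image sizes below are arbitrary example values.

import torch

batch_size, num_patches = 2, 17  # arbitrary example values
pixel_values = torch.zeros(batch_size, num_patches, 3, 336, 336)
image_sizes = torch.tensor([[1344, 1344], [672, 1008]])

# Mirrors _validate_pixel_values: everything after the (batch, patches)
# dimensions must be (3, image_size, image_size).
assert list(pixel_values.shape)[2:] == [3, 336, 336]

# Mirrors _validate_image_sizes: one (height, width) pair per batch entry.
assert list(image_sizes.shape[1:]) == [2]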
View File

@ -120,3 +120,10 @@ class MultiModalRegistry:
Create an input mapper (see :meth:`map_input`) for a specific model.
"""
return functools.partial(self.map_input, model_config)
def get_num_input_tokens(self):
"""
Get the number of input tokens for profiling purposes.
"""
# TODO: Provide this number on a per model basis.
return 3000
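A minimal usage sketch of this stopgap value as it is consumed during memory profiling, mirroring the model-runner hunks further down. It assumes the registry method added in this commit; the token budget and sequence cap are illustrative numbers, not vLLM defaults.

from vllm.multimodal import MULTIMODAL_REGISTRY

max_num_batched_tokens = 8192  # illustrative scheduler budget
max_num_seqs = 256             # illustrative upper bound

# Conservative per-sequence multimodal token count (hardcoded to 3000 for now).
mm_tokens = MULTIMODAL_REGISTRY.get_num_input_tokens()
max_num_seqs = max(1, min(max_num_seqs, max_num_batched_tokens // mm_tokens))
# -> profiling runs with at most 2 sequences instead of 256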

View File

@ -3,8 +3,8 @@ from typing import List, Optional
import torch
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig)
from vllm.logger import init_logger
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
@ -47,7 +47,7 @@ class TP1DraftModelRunner(ModelRunner):
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
vision_language_config: Optional[VisionLanguageConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
return_hidden_states: bool = False,
):
if return_hidden_states:
@ -65,7 +65,7 @@ class TP1DraftModelRunner(ModelRunner):
lora_config=lora_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker,
vision_language_config=vision_language_config,
multimodal_config=multimodal_config,
return_hidden_states=return_hidden_states,
)

View File

@ -7,8 +7,8 @@ from torch import nn
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
@ -79,7 +79,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
multimodal_config: Optional[MultiModalConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
*args,
@ -93,7 +93,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
self.device_config = device_config
self.cache_config = cache_config
self.lora_config = lora_config
self.vision_language_config = vision_language_config
self.multimodal_config = multimodal_config
self.load_config = load_config
self.is_driver_worker = is_driver_worker
@ -120,15 +120,14 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
self.model: nn.Module # Set after init_Model
def load_model(self) -> None:
self.model = get_model(
model_config=self.model_config,
load_config=self.load_config,
device_config=self.device_config,
vision_language_config=self.vision_language_config,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config)
self.model = get_model(model_config=self.model_config,
load_config=self.load_config,
device_config=self.device_config,
multimodal_config=self.multimodal_config,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config)
def _prepare_prompt(
self,

View File

@ -6,8 +6,8 @@ import torch.distributed
from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
@ -131,7 +131,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
vision_language_config: Optional[VisionLanguageConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
) -> None:
@ -145,7 +145,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.vision_language_config = vision_language_config
self.multimodal_config = multimodal_config
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."
@ -162,7 +162,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
cache_config,
load_config=self.load_config,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
multimodal_config=self.multimodal_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker)
# Uninitialized cache engine. Will be initialized by

View File

@ -4,8 +4,8 @@ from typing import Any, Dict, List, Optional, Tuple, Type
import torch
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.pooling_params import PoolingParams
@ -40,7 +40,7 @@ class EmbeddingModelRunner(
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
vision_language_config: Optional[VisionLanguageConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
):
super().__init__(model_config,
parallel_config,
@ -51,7 +51,7 @@ class EmbeddingModelRunner(
lora_config=lora_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker,
vision_language_config=vision_language_config)
multimodal_config=multimodal_config)
@torch.inference_mode()
def execute_model(

View File

@ -24,8 +24,8 @@ except ImportError:
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig)
from vllm.distributed import get_pp_group
from vllm.distributed.parallel_state import graph_capture
from vllm.inputs import INPUT_REGISTRY
@ -36,7 +36,8 @@ from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models.interfaces import supports_lora
from vllm.model_executor.models.interfaces import (supports_lora,
supports_vision)
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
MultiModalInputs)
from vllm.sampling_params import SamplingParams
@ -171,7 +172,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
vision_language_config: Optional[VisionLanguageConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
return_hidden_states: bool = False,
):
self.model_config = model_config
@ -182,7 +183,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.lora_config = lora_config
self.load_config = load_config
self.is_driver_worker = is_driver_worker
self.vision_language_config = vision_language_config
self.multimodal_config = multimodal_config
self.return_hidden_states = return_hidden_states
self.device = self.device_config.device
@ -244,7 +245,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
device_config=self.device_config,
load_config=self.load_config,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
multimodal_config=self.multimodal_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config,
@ -256,6 +257,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
if self.lora_config:
assert supports_lora(self.model), "Model does not support LoRA"
assert not supports_vision(
self.model
), "To be tested: vision language model with LoRA settings."
self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs,
@ -804,12 +808,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
# the number of seqs (batch_size) is chosen to maximize the number
# of images processed.
model_config = self.model_config
vlm_config = self.vision_language_config
if vlm_config:
max_num_seqs = min(
max_num_seqs,
int(max_num_batched_tokens / vlm_config.image_feature_size))
if supports_vision(self.model):
max_num_seqs = max(
1,
min(
max_num_seqs,
int(max_num_batched_tokens /
MULTIMODAL_REGISTRY.get_num_input_tokens())))
batch_size = 0
for group_id in range(max_num_seqs):
seq_len = (max_num_batched_tokens // max_num_seqs +

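To make the behavioral change in profile_run concrete, a worked example with illustrative numbers: 576 stands in for the old per-model image_feature_size (the value used by LLaVA-1.5), and 3000 is the conservative registry default introduced above.

max_num_batched_tokens, max_num_seqs = 16384, 256

# Before this change: capped by the per-model image_feature_size,
# e.g. 576 for LLaVA-1.5, so profiling could use up to 28 sequences.
old_cap = min(max_num_seqs, max_num_batched_tokens // 576)           # 28

# After this change: capped by the registry default of 3000,
# so profiling uses at most 5 sequences (and never fewer than 1).
new_cap = max(1, min(max_num_seqs, max_num_batched_tokens // 3000))  # 5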
View File

@ -7,8 +7,8 @@ from torch import nn
from vllm.attention import get_attn_backend
from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader.openvino import get_model
@ -48,7 +48,7 @@ class OpenVINOModelRunner:
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
multimodal_config: Optional[MultiModalConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
*args,
@ -60,7 +60,7 @@ class OpenVINOModelRunner:
self.device_config = device_config
self.cache_config = cache_config
self.lora_config = lora_config
self.vision_language_config = vision_language_config
self.multimodal_config = multimodal_config
self.load_config = load_config
self.is_driver_worker = is_driver_worker

View File

@ -7,8 +7,8 @@ import torch.distributed
from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig)
from vllm.distributed import (broadcast_tensor_dict,
ensure_model_parallel_initialized,
init_distributed_environment)
@ -148,7 +148,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
vision_language_config: Optional[VisionLanguageConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
kv_cache_dtype: Optional[ov.Type] = ov.Type.undefined,
is_driver_worker: bool = False,
) -> None:
@ -162,7 +162,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.vision_language_config = vision_language_config
self.multimodal_config = multimodal_config
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."
@ -180,7 +180,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
cache_config,
load_config=self.load_config,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
multimodal_config=self.multimodal_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker,
)

View File

@ -8,7 +8,7 @@ import torch_xla.core.xla_model as xm
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
MultiModalConfig, ParallelConfig, SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.sampling_metadata import SamplingMetadata
@ -39,7 +39,7 @@ class TPUModelRunner:
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
vision_language_config: Optional[VisionLanguageConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
is_driver_worker: bool = False,
):
self.model_config = model_config
@ -48,7 +48,7 @@ class TPUModelRunner:
self.device_config = device_config
self.cache_config = cache_config
self.load_config = load_config
self.vision_language_config = vision_language_config
self.multimodal_config = multimodal_config
self.is_driver_worker = is_driver_worker
self.block_size = self.cache_config.block_size
@ -82,7 +82,7 @@ class TPUModelRunner:
parallel_config=self.parallel_config,
cache_config=self.cache_config,
scheduler_config=self.scheduler_config,
vision_language_config=self.vision_language_config,
multimodal_config=self.multimodal_config,
lora_config=None,
)
xm.wait_device_ops()

View File

@ -8,7 +8,7 @@ import torch_xla.runtime as xr
import vllm.envs as envs
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
MultiModalConfig, ParallelConfig, SchedulerConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
@ -31,7 +31,7 @@ class TPUWorker(LoraNotSupportedWorkerBase):
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
vision_language_config: Optional[VisionLanguageConfig],
multimodal_config: Optional[MultiModalConfig],
local_rank: int,
rank: int,
distributed_init_method: str,
@ -43,7 +43,7 @@ class TPUWorker(LoraNotSupportedWorkerBase):
self.device_config = device_config
self.cache_config = cache_config
self.load_config = load_config
self.vision_language_config = vision_language_config
self.multimodal_config = multimodal_config
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
@ -62,7 +62,7 @@ class TPUWorker(LoraNotSupportedWorkerBase):
device_config,
cache_config,
load_config,
vision_language_config,
multimodal_config,
is_driver_worker=is_driver_worker)
def init_device(self) -> None:

View File

@ -7,8 +7,8 @@ import torch
import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
SpeculativeConfig, VisionLanguageConfig)
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig, SpeculativeConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment,
set_custom_all_reduce)
@ -43,7 +43,7 @@ class Worker(LocalOrDistributedWorkerBase):
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
vision_language_config: Optional[VisionLanguageConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
speculative_config: Optional[SpeculativeConfig] = None,
is_driver_worker: bool = False,
model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None,
@ -66,10 +66,7 @@ class Worker(LocalOrDistributedWorkerBase):
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
self.vision_language_config = vision_language_config
if self.vision_language_config:
assert not self.lora_config, (
"To be tested: vision language model with LoRA settings.")
self.multimodal_config = multimodal_config
# Return hidden states from target model if the draft model is an
# mlp_speculator
@ -94,7 +91,7 @@ class Worker(LocalOrDistributedWorkerBase):
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=is_driver_worker,
vision_language_config=vision_language_config,
multimodal_config=multimodal_config,
**speculative_args,
)
# Uninitialized cache engine. Will be initialized by

View File

@ -7,12 +7,13 @@ import torch.nn as nn
from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig)
from vllm.distributed import broadcast_tensor_dict
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models.interfaces import supports_vision
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
MultiModalInputs)
from vllm.sampling_params import SamplingParams
@ -85,7 +86,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
multimodal_config: Optional[MultiModalConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
*args,
@ -97,7 +98,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
self.lora_config = lora_config
self.load_config = load_config
self.cache_config = cache_config
self.vision_language_config = vision_language_config
self.multimodal_config = multimodal_config
self.is_driver_worker = is_driver_worker
self.sliding_window = model_config.get_sliding_window()
@ -134,7 +135,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
device_config=self.device_config,
load_config=self.load_config,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
multimodal_config=self.multimodal_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config,
@ -165,12 +166,16 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
# the number of seqs (batch_size) is chosen to maximize the number
# of images processed.
model_config = self.model_config
vlm_config = self.vision_language_config
if vlm_config:
max_num_seqs = min(
max_num_seqs,
int(max_num_batched_tokens / vlm_config.image_feature_size))
if supports_vision(self.model):
# TODO: properly inject these numbers from MultiModalRegistry.
# Right now, just use an overly conservative number.
max_num_seqs = max(
1,
min(
max_num_seqs,
int(max_num_batched_tokens /
MULTIMODAL_REGISTRY.get_num_input_tokens())))
for group_id in range(max_num_seqs):
seq_len = (max_num_batched_tokens // max_num_seqs +

View File

@ -9,8 +9,8 @@ import torch
import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
SpeculativeConfig, VisionLanguageConfig)
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig, SpeculativeConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
@ -45,7 +45,7 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
vision_language_config: Optional[VisionLanguageConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
speculative_config: Optional[SpeculativeConfig] = None,
is_driver_worker: bool = False,
) -> None:
@ -66,10 +66,7 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."
self.vision_language_config = vision_language_config
if self.vision_language_config:
assert not self.lora_config, (
"To be tested: vision language model with LoRA settings.")
self.multimodal_config = multimodal_config
self.model_runner = XPUModelRunner( # type: ignore
model_config,
@ -81,7 +78,7 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=is_driver_worker,
vision_language_config=vision_language_config,
multimodal_config=multimodal_config,
)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.