[vlm] Remove vision language config. (#6089)

Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: Roger Wang <ywang@roblox.com>

parent 3c6325f0fc · commit d9e98f42e4
@@ -10,8 +10,13 @@ vLLM provides experimental support for multi-modal models through the :mod:`vllm

:class:`vllm.inputs.PromptStrictInputs` accepts an additional attribute ``multi_modal_data``
which allows you to pass in multi-modal input alongside text and token prompts.
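For orientation, the following is a minimal sketch of that flow, assuming an image-capable model such as ``llava-hf/llava-1.5-7b-hf`` and a local image file (the file name is illustrative); it mirrors the usage shown in the VLM guide further below:

.. code-block:: python

    from PIL import Image
    from vllm import LLM

    llm = LLM(model="llava-hf/llava-1.5-7b-hf")
    image = Image.open("example.jpg")

    outputs = llm.generate({
        # Text prompt containing the model's image placeholder token
        "prompt": "USER: <image>\nWhat is the content of this image?\nASSISTANT:",
        # Multi-modal input passed alongside the prompt
        "multi_modal_data": {"image": image},
    })
    print(outputs[0].outputs[0].text)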
.. note::
    ``multi_modal_data`` can accept keys and values beyond the builtin ones, as long as a customized plugin is registered through
    :class:`vllm.multimodal.MULTIMODAL_REGISTRY`.

By default, vLLM models do not support multi-modal inputs. To enable multi-modal support for a model, please follow :ref:`the guide for adding a new multimodal model. <adding_a_new_multimodal_model>`.

# TODO: Add more instructions on how to do that once embeddings is in.

Guides
@@ -8,18 +8,6 @@ vLLM provides experimental support for Vision Language Models (VLMs). This docum

.. important::
    We are actively iterating on VLM support. Expect breaking changes to VLM usage and development in upcoming releases without prior deprecation.

Engine Arguments
----------------

The following :ref:`engine arguments <engine_args>` are specific to VLMs:

.. argparse::
    :module: vllm.engine.arg_utils
    :func: _vlm_engine_args_parser
    :prog: -m vllm.entrypoints.openai.api_server
    :nodefaultconst:

.. important::
    Currently, the support for vision language models on vLLM has the following limitations:

    * Only single image input is supported per text prompt.
@@ -33,20 +21,17 @@ To initialize a VLM, the aforementioned arguments must be passed to the ``LLM`` 

.. code-block:: python

    llm = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        image_token_id=32000,
        image_input_shape="1,3,336,336",
        image_feature_size=576,
    )
    llm = LLM(model="llava-hf/llava-1.5-7b-hf")

.. important::
    Currently, you have to specify ``image_feature_size`` to support memory profiling.
    To avoid OOM during runtime, you should set this to the maximum value supported by the model.
    The calculation of feature size is specific to the model. For more details, please refer to
    the function :code:`get_<model_name>_image_feature_size` inside the corresponding model file.
    We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
    the above snippet. Specifically, ``image_feature_size`` is no longer required to be specified, and internally we will construct data structures for
    every model to perform profiling with.

    We will remove most of the vision-specific arguments in a future release as they can be inferred from the HuggingFace configuration.
    This work is still ongoing. In the meantime, we internally hardcode ``image_feature_size = 3000`` through
    :meth:`MULTIMODAL_REGISTRY.get_num_input_tokens <vllm.multimodal.MultiModalRegistry.get_num_input_tokens>`
    for every model to be conservative in terms of GPU memory consumption. This hardcoded value will be replaced
    with a more accurate profiling strategy in the future.


To pass an image to the model, note the following in :class:`vllm.inputs.PromptStrictInputs`:
@@ -54,19 +39,15 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptS

* ``prompt``: The prompt should follow the format that is documented on HuggingFace.
* ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.

.. note::

    ``multi_modal_data`` can accept keys and values beyond the builtin ones, as long as a customized plugin is registered through
    :class:`vllm.multimodal.MULTIMODAL_REGISTRY`.

.. code-block:: python

    # Refer to the HuggingFace repo for the correct format to use
    prompt = "USER: <image>\nWhat is the content of this image?\nASSISTANT:"

    # Load the image using PIL.Image
    image = ...

    image = PIL.Image.open(...)

    # Single prompt inference
    outputs = llm.generate({
        "prompt": prompt,
        "multi_modal_data": {"image": image},

@@ -75,6 +56,26 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptS

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)

    # Batch inference
    image_1 = PIL.Image.open(...)
    image_2 = PIL.Image.open(...)
    outputs = llm.generate(
        [
            {
                "prompt": "USER: <image>\nWhat is the content of this image?\nASSISTANT:",
                "multi_modal_data": {"image": image_1},
            },
            {
                "prompt": "USER: <image>\nWhat's the color of this image?\nASSISTANT:",
                "multi_modal_data": {"image": image_2},
            }
        ]
    )

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)

A code example can be found in `examples/llava_example.py <https://github.com/vllm-project/vllm/blob/main/examples/llava_example.py>`_.
@@ -99,18 +100,17 @@ Below is an example on how to launch the same ``llava-hf/llava-1.5-7b-hf`` with

    python -m vllm.entrypoints.openai.api_server \
        --model llava-hf/llava-1.5-7b-hf \
        --image-token-id 32000 \
        --image-input-shape 1,3,336,336 \
        --image-feature-size 576 \
        --chat-template template_llava.jinja

.. important::
    Currently, you have to specify ``image_feature_size`` to support memory profiling.
    To avoid OOM during runtime, you should set this to the maximum value supported by the model.
    The calculation of feature size is specific to the model. For more details, please refer to
    the function :code:`get_<model_name>_image_feature_size` inside the corresponding model file.
    We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
    the above snippet. Specifically, ``image_feature_size`` is no longer required to be specified, and internally we will construct data structures for
    every model to perform profiling with.

    We will remove most of the vision-specific arguments in a future release as they can be inferred from the HuggingFace configuration.
    This work is still ongoing. In the meantime, we internally hardcode ``image_feature_size = 3000`` through
    :meth:`MULTIMODAL_REGISTRY.get_num_input_tokens <vllm.multimodal.MultiModalRegistry.get_num_input_tokens>`
    for every model to be conservative in terms of GPU memory consumption. This hardcoded value will be replaced
    with a more accurate profiling strategy in the future.

To consume the server, you can use the OpenAI client like in the example below:
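The client example itself is truncated in this diff. As a rough sketch, consuming the server with the official ``openai`` Python client looks like the following (the local base URL, placeholder API key, and image URL are illustrative assumptions):

.. code-block:: python

    from openai import OpenAI

    # Point the client at the locally running vLLM OpenAI-compatible server.
    client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")

    chat_response = client.chat.completions.create(
        model="llava-hf/llava-1.5-7b-hf",
        messages=[{
            "role": "user",
            "content": [
                {"type": "text", "text": "What's in this image?"},
                {
                    "type": "image_url",
                    "image_url": {"url": "https://example.com/some_image.jpg"},
                },
            ],
        }],
    )
    print(chat_response.choices[0].message.content)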
@@ -10,12 +10,7 @@ from vllm import LLM


def run_llava():
    llm = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        image_token_id=32000,
        image_input_shape="1,3,336,336",
        image_feature_size=576,
    )
    llm = LLM(model="llava-hf/llava-1.5-7b-hf")

    prompt = "USER: <image>\nWhat is the content of this image?\nASSISTANT:"
@@ -7,13 +7,7 @@ from vllm import LLM, SamplingParams


def run_llava_next():
    llm = LLM(
        model="llava-hf/llava-v1.6-mistral-7b-hf",
        image_token_id=32000,
        image_input_shape="1,3,336,336",
        # Use the maximum possible value for memory profiling
        image_feature_size=2928,
    )
    llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", max_model_len=4096)

    prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
    url = "https://h2o-release.s3.amazonaws.com/h2ogpt/bigben.jpg"
@@ -3,9 +3,6 @@

Launch the vLLM server with the following command:
python -m vllm.entrypoints.openai.api_server \
    --model llava-hf/llava-1.5-7b-hf \
    --image-token-id 32000 \
    --image-input-shape 1,3,336,336 \
    --image-feature-size 576 \
    --chat-template template_llava.jinja
"""
import base64
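The client file above imports ``base64``, presumably so that local images can be sent as data URLs rather than remote links. A minimal sketch of that encoding step (the helper name, file path, and MIME type are assumptions, not part of the original file):

.. code-block:: python

    import base64

    def encode_image_to_data_url(path: str, mime: str = "image/jpeg") -> str:
        """Read a local image and wrap it in a data URL usable as an OpenAI image_url."""
        with open(path, "rb") as f:
            encoded = base64.b64encode(f.read()).decode("utf-8")
        return f"data:{mime};base64,{encoded}"

    # Example: pass the result as {"type": "image_url", "image_url": {"url": data_url}}
    data_url = encode_image_to_data_url("example.jpg")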
@@ -14,15 +14,13 @@ def run_phi3v():

    # Note: The default setting of max_num_seqs (256) and
    # max_model_len (128k) for this model may cause OOM.
    # You may lower either to run this example on lower-end GPUs.

    # In this example, we override max_num_seqs to 5 while
    # keeping the original context length of 128k.
    llm = LLM(
        model=model_path,
        trust_remote_code=True,
        image_token_id=32044,
        image_input_shape="1,3,1008,1344",
        # Use the maximum possible value for memory profiling
        image_feature_size=2653,
        max_num_seqs=5,
    )
@@ -20,9 +20,9 @@ from vllm.utils import cuda_device_count_stateless

model = os.environ["TEST_DIST_MODEL"]

if model.startswith("llava-hf/llava"):
    from ..models.test_llava import model_and_vl_config, run_test
    from ..models.test_llava import models, run_test
elif model.startswith("microsoft/Phi-3-vision"):
    from ..models.test_phi3v import model_and_vl_config, run_test
    from ..models.test_phi3v import models, run_test
else:
    raise NotImplementedError(f"Unsupported model: {model}")

@@ -44,7 +44,7 @@ def test_models(hf_runner, vllm_runner, image_assets,

        hf_runner,
        vllm_runner,
        image_assets,
        model_and_config=model_and_vl_config[0],
        model=models[0],
        size_factors=[1.0],
        dtype=dtype,
        max_tokens=max_tokens,
@@ -39,12 +39,6 @@ def server(ray_ctx):

        "--max-model-len",
        "4096",
        "--enforce-eager",
        "--image-token-id",
        "32000",
        "--image-input-shape",
        "1,3,336,336",
        "--image-feature-size",
        "576",
        "--chat-template",
        str(LLAVA_CHAT_TEMPLATE),
    ])
@ -3,7 +3,6 @@ from typing import List, Optional, Tuple, Type
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.config import VisionLanguageConfig
|
||||
from vllm.multimodal.utils import rescale_image_size
|
||||
from vllm.sequence import SampleLogprobs
|
||||
|
||||
@ -21,49 +20,27 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
|
||||
"USER: <image>\nWhat's in this image?\nASSISTANT:",
|
||||
})
|
||||
|
||||
IMAGE_TOKEN_ID = 32000
|
||||
|
||||
def iter_llava_configs(model_name: str):
|
||||
image_hw_to_feature_size = {
|
||||
(336, 336): 576,
|
||||
}
|
||||
|
||||
for (h, w), f in image_hw_to_feature_size.items():
|
||||
input_shape = (1, 3, h, w)
|
||||
yield (model_name,
|
||||
VisionLanguageConfig(image_feature_size=f,
|
||||
image_token_id=32000,
|
||||
image_input_shape=input_shape))
|
||||
|
||||
|
||||
model_and_vl_config = [
|
||||
*iter_llava_configs("llava-hf/llava-1.5-7b-hf"),
|
||||
]
|
||||
models = ["llava-hf/llava-1.5-7b-hf"]
|
||||
|
||||
|
||||
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
|
||||
Optional[SampleLogprobs]],
|
||||
vlm_config: VisionLanguageConfig, model_id: str):
|
||||
"""Sanitize vllm output to be comparable with hf output.
|
||||
The function reduces `input_ids` from 1, 32000, 32000, ..., 32000,
|
||||
x1, x2, x3 ... to 1, 32000, x1, x2, x3 ...
|
||||
It also reduces `output_str` from "<image><image>bla" to "bla".
|
||||
"""
|
||||
model: str):
|
||||
"""Sanitize vllm output to be comparable with hf output."""
|
||||
output_ids, output_str, out_logprobs = vllm_output
|
||||
image_token_id = vlm_config.image_token_id
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
image_token_str = tokenizer.decode(image_token_id)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model)
|
||||
eos_token_id = tokenizer.eos_token_id
|
||||
|
||||
hf_output_ids = [
|
||||
token_id for idx, token_id in enumerate(output_ids)
|
||||
if token_id != image_token_id or output_ids[idx - 1] != image_token_id
|
||||
if token_id != IMAGE_TOKEN_ID or output_ids[idx - 1] != IMAGE_TOKEN_ID
|
||||
]
|
||||
|
||||
hf_output_str = output_str \
|
||||
.replace(image_token_str * vlm_config.image_feature_size, "")
|
||||
assert hf_output_str[0] == " "
|
||||
hf_output_str = hf_output_str[1:]
|
||||
assert output_str[0] == " "
|
||||
hf_output_str = output_str[1:]
|
||||
if hf_output_ids[-1] == eos_token_id:
|
||||
hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)
|
||||
|
||||
@ -74,7 +51,7 @@ def run_test(
|
||||
hf_runner: Type[HfRunner],
|
||||
vllm_runner: Type[VllmRunner],
|
||||
image_assets: _ImageAssets,
|
||||
model_and_config: Tuple[str, VisionLanguageConfig],
|
||||
model: str,
|
||||
*,
|
||||
size_factors: List[float],
|
||||
dtype: str,
|
||||
@ -92,7 +69,6 @@ def run_test(
|
||||
Note, the text input is also adjusted to abide by vllm contract.
|
||||
The text output is sanitized to be able to compare with hf.
|
||||
"""
|
||||
model_id, vlm_config = model_and_config
|
||||
images = [asset.pil_image for asset in image_assets]
|
||||
|
||||
inputs_per_image = [(
|
||||
@ -106,12 +82,11 @@ def run_test(
|
||||
# will hurt multiprocessing backend with fork method (the default method).
|
||||
|
||||
# max_model_len should be greater than image_feature_size
|
||||
with vllm_runner(model_id,
|
||||
with vllm_runner(model,
|
||||
dtype=dtype,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
enforce_eager=True,
|
||||
**vlm_config.as_cli_args_dict()) as vllm_model:
|
||||
enforce_eager=True) as vllm_model:
|
||||
vllm_outputs_per_image = [
|
||||
vllm_model.generate_greedy_logprobs(prompts,
|
||||
max_tokens,
|
||||
@ -120,7 +95,7 @@ def run_test(
|
||||
for prompts, images in inputs_per_image
|
||||
]
|
||||
|
||||
with hf_runner(model_id, dtype=dtype, is_vision_model=True) as hf_model:
|
||||
with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
|
||||
hf_outputs_per_image = [
|
||||
hf_model.generate_greedy_logprobs_limit(prompts,
|
||||
max_tokens,
|
||||
@ -136,7 +111,7 @@ def run_test(
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=[
|
||||
vllm_to_hf_output(vllm_output, vlm_config, model_id)
|
||||
vllm_to_hf_output(vllm_output, model)
|
||||
for vllm_output in vllm_outputs
|
||||
],
|
||||
name_0="hf",
|
||||
@ -144,7 +119,7 @@ def run_test(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
|
||||
@pytest.mark.parametrize("model", models)
|
||||
@pytest.mark.parametrize(
|
||||
"size_factors",
|
||||
[
|
||||
@ -161,14 +136,13 @@ def run_test(
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
|
||||
size_factors, dtype: str, max_tokens: int,
|
||||
num_logprobs: int) -> None:
|
||||
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
|
||||
dtype: str, max_tokens: int, num_logprobs: int) -> None:
|
||||
run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
image_assets,
|
||||
model_and_config,
|
||||
model,
|
||||
size_factors=size_factors,
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
|
||||
@ -4,7 +4,6 @@ from typing import List, Optional, Tuple
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.config import VisionLanguageConfig
|
||||
from vllm.multimodal.utils import rescale_image_size
|
||||
from vllm.sequence import SampleLogprobs
|
||||
|
||||
@ -27,46 +26,22 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
|
||||
f"{_PREFACE} USER: <image>\nWhat's in this image? ASSISTANT:",
|
||||
})
|
||||
|
||||
|
||||
def iter_llava_next_configs(model_name: str):
|
||||
# Need to use the max possible feature size for profile_run
|
||||
image_hw_to_feature_size = {
|
||||
(336, 336): 2928,
|
||||
}
|
||||
|
||||
for (h, w), f in image_hw_to_feature_size.items():
|
||||
input_shape = (1, 3, h, w)
|
||||
yield (model_name,
|
||||
VisionLanguageConfig(
|
||||
image_feature_size=f,
|
||||
image_token_id=32000,
|
||||
image_input_shape=input_shape,
|
||||
))
|
||||
|
||||
|
||||
model_and_vl_config = [
|
||||
*iter_llava_next_configs("llava-hf/llava-v1.6-vicuna-7b-hf"),
|
||||
]
|
||||
IMAGE_TOKEN_ID = 32000
|
||||
|
||||
|
||||
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
|
||||
Optional[SampleLogprobs]],
|
||||
vlm_config: VisionLanguageConfig, model_id: str):
|
||||
"""Sanitize vllm output to be comparable with hf output.
|
||||
The function reduces `input_ids` from 1, 32000, 32000, ..., 32000,
|
||||
x1, x2, x3 ... to 1, 32000, x1, x2, x3 ...
|
||||
It also reduces `output_str` from "<image><image>bla" to "bla".
|
||||
"""
|
||||
model: str):
|
||||
"""Sanitize vllm output to be comparable with hf output."""
|
||||
output_ids, output_str, out_logprobs = vllm_output
|
||||
image_token_id = vlm_config.image_token_id
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
image_token_str = tokenizer.decode(image_token_id)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model)
|
||||
image_token_str = tokenizer.decode(IMAGE_TOKEN_ID)
|
||||
eos_token_id = tokenizer.eos_token_id
|
||||
|
||||
hf_output_ids = [
|
||||
token_id for idx, token_id in enumerate(output_ids)
|
||||
if token_id != image_token_id or output_ids[idx - 1] != image_token_id
|
||||
if token_id != IMAGE_TOKEN_ID or output_ids[idx - 1] != IMAGE_TOKEN_ID
|
||||
]
|
||||
|
||||
hf_output_str = re.sub(fr"({image_token_str})+", "", output_str)
|
||||
@ -78,7 +53,7 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
|
||||
return hf_output_ids, hf_output_str, out_logprobs
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
|
||||
@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-vicuna-7b-hf"])
|
||||
@pytest.mark.parametrize(
|
||||
"size_factors",
|
||||
[
|
||||
@ -95,9 +70,8 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
|
||||
size_factors, dtype: str, max_tokens: int,
|
||||
num_logprobs: int) -> None:
|
||||
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
|
||||
dtype, max_tokens, num_logprobs) -> None:
|
||||
"""Inference result should be the same between hf and vllm.
|
||||
|
||||
All the image fixtures for the test is under tests/images.
|
||||
@ -107,7 +81,6 @@ def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
|
||||
Note, the text input is also adjusted to abide by vllm contract.
|
||||
The text output is sanitized to be able to compare with hf.
|
||||
"""
|
||||
model_id, vlm_config = model_and_config
|
||||
images = [asset.pil_image for asset in image_assets]
|
||||
|
||||
inputs_per_image = [(
|
||||
@ -116,11 +89,10 @@ def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
|
||||
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
|
||||
|
||||
# max_model_len should be greater than image_feature_size
|
||||
with vllm_runner(model_id,
|
||||
with vllm_runner(model,
|
||||
dtype=dtype,
|
||||
max_model_len=4096,
|
||||
enforce_eager=True,
|
||||
**vlm_config.as_cli_args_dict()) as vllm_model:
|
||||
enforce_eager=True) as vllm_model:
|
||||
vllm_outputs_per_image = [
|
||||
vllm_model.generate_greedy_logprobs(prompts,
|
||||
max_tokens,
|
||||
@ -129,7 +101,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
|
||||
for prompts, images in inputs_per_image
|
||||
]
|
||||
|
||||
with hf_runner(model_id, dtype=dtype, is_vision_model=True) as hf_model:
|
||||
with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
|
||||
hf_outputs_per_image = [
|
||||
hf_model.generate_greedy_logprobs_limit(prompts,
|
||||
max_tokens,
|
||||
@ -145,7 +117,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=[
|
||||
vllm_to_hf_output(vllm_output, vlm_config, model_id)
|
||||
vllm_to_hf_output(vllm_output, model)
|
||||
for vllm_output in vllm_outputs
|
||||
],
|
||||
name_0="hf",
|
||||
|
||||
@ -4,7 +4,6 @@ from typing import List, Optional, Tuple, Type
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.config import VisionLanguageConfig
|
||||
from vllm.multimodal.utils import rescale_image_size
|
||||
from vllm.sequence import SampleLogprobs
|
||||
from vllm.utils import is_cpu
|
||||
@ -23,35 +22,14 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
|
||||
"<|user|>\n<|image_1|>\nWhat's in this image?<|end|>\n<|assistant|>\n",
|
||||
})
|
||||
|
||||
|
||||
def iter_phi3v_configs(model_name: str):
|
||||
# Need to use the max possible feature size for profile_run
|
||||
image_hw_to_feature_size = {
|
||||
(1008, 1344): 2653,
|
||||
}
|
||||
|
||||
for (h, w), f in image_hw_to_feature_size.items():
|
||||
input_shape = (1, 3, h, w)
|
||||
yield (model_name,
|
||||
VisionLanguageConfig(image_feature_size=f,
|
||||
image_token_id=32044,
|
||||
image_input_shape=input_shape))
|
||||
|
||||
|
||||
model_and_vl_config = [
|
||||
*iter_phi3v_configs("microsoft/Phi-3-vision-128k-instruct"),
|
||||
]
|
||||
models = ["microsoft/Phi-3-vision-128k-instruct"]
|
||||
|
||||
|
||||
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
|
||||
Optional[SampleLogprobs]],
|
||||
vlm_config: VisionLanguageConfig, model_id: str):
|
||||
"""Sanitize vllm output to be comparable with hf output.
|
||||
The function reduces `input_ids` from 1, 32000, 32000, ..., 32000,
|
||||
x1, x2, x3 ... to 1, 32000, x1, x2, x3 ...
|
||||
It also reduces `output_str` from "<image><image>bla" to "bla".
|
||||
"""
|
||||
output_ids, output_str, out_logprobs = vllm_output
|
||||
model: str):
|
||||
"""Sanitize vllm output to be comparable with hf output."""
|
||||
_, output_str, out_logprobs = vllm_output
|
||||
|
||||
output_str_without_image = re.sub(r"(<\|image_\d+\|>)+", "", output_str)
|
||||
assert output_str_without_image[0] == " "
|
||||
@ -60,7 +38,7 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
|
||||
hf_output_str = output_str_without_image.replace("<|user|>", "") \
|
||||
.replace("<|end|>\n<|assistant|>", " ")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model)
|
||||
hf_output_ids = tokenizer.encode(output_str_without_image)
|
||||
assert hf_output_ids[0] == 1
|
||||
hf_output_ids = hf_output_ids[1:]
|
||||
@ -77,7 +55,7 @@ def run_test(
|
||||
hf_runner: Type[HfRunner],
|
||||
vllm_runner: Type[VllmRunner],
|
||||
image_assets: _ImageAssets,
|
||||
model_and_config: Tuple[str, VisionLanguageConfig],
|
||||
model: str,
|
||||
*,
|
||||
size_factors: List[float],
|
||||
dtype: str,
|
||||
@ -95,7 +73,6 @@ def run_test(
|
||||
Note, the text input is also adjusted to abide by vllm contract.
|
||||
The text output is sanitized to be able to compare with hf.
|
||||
"""
|
||||
model_id, vlm_config = model_and_config
|
||||
images = [asset.pil_image for asset in image_assets]
|
||||
|
||||
inputs_per_image = [(
|
||||
@ -109,13 +86,13 @@ def run_test(
|
||||
# will hurt multiprocessing backend with fork method (the default method).
|
||||
|
||||
# max_model_len should be greater than image_feature_size
|
||||
with vllm_runner(model_id,
|
||||
with vllm_runner(model,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=1,
|
||||
dtype=dtype,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
enforce_eager=True,
|
||||
**vlm_config.as_cli_args_dict()) as vllm_model:
|
||||
enforce_eager=True) as vllm_model:
|
||||
vllm_outputs_per_image = [
|
||||
vllm_model.generate_greedy_logprobs(prompts,
|
||||
max_tokens,
|
||||
@ -126,7 +103,7 @@ def run_test(
|
||||
|
||||
# use eager mode for hf runner, since phi3_v didn't work with flash_attn
|
||||
hf_model_kwargs = {"_attn_implementation": "eager"}
|
||||
with hf_runner(model_id, dtype=dtype,
|
||||
with hf_runner(model, dtype=dtype,
|
||||
model_kwargs=hf_model_kwargs) as hf_model:
|
||||
eos_token_id = hf_model.processor.tokenizer.eos_token_id
|
||||
hf_outputs_per_image = [
|
||||
@ -143,7 +120,7 @@ def run_test(
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=[
|
||||
vllm_to_hf_output(vllm_output, vlm_config, model_id)
|
||||
vllm_to_hf_output(vllm_output, model)
|
||||
for vllm_output in vllm_outputs
|
||||
],
|
||||
name_0="hf",
|
||||
@ -153,7 +130,7 @@ def run_test(
|
||||
|
||||
# Since we use _attn_implementation="eager" for hf_runner, there is more
|
||||
# significant numerical difference. The basic `logprobs=5` fails to pass.
|
||||
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
|
||||
@pytest.mark.parametrize("model", models)
|
||||
@pytest.mark.parametrize(
|
||||
"size_factors",
|
||||
[
|
||||
@ -170,14 +147,13 @@ def run_test(
|
||||
@pytest.mark.parametrize("dtype", [target_dtype])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
@pytest.mark.parametrize("num_logprobs", [10])
|
||||
def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
|
||||
size_factors, dtype: str, max_tokens: int,
|
||||
num_logprobs: int) -> None:
|
||||
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
|
||||
dtype: str, max_tokens: int, num_logprobs: int) -> None:
|
||||
run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
image_assets,
|
||||
model_and_config,
|
||||
model,
|
||||
size_factors=size_factors,
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
|
||||
@@ -1,8 +1,7 @@
import enum
import json
from dataclasses import dataclass, field, fields
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple,
                    Union)
from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Union

import torch
from transformers import PretrainedConfig

@@ -120,7 +119,7 @@ class ModelConfig:
        disable_sliding_window: bool = False,
        skip_tokenizer_init: bool = False,
        served_model_name: Optional[Union[str, List[str]]] = None,
        multimodal_config: Optional["VisionLanguageConfig"] = None,
        multimodal_config: Optional["MultiModalConfig"] = None,
    ) -> None:
        self.model = model
        self.tokenizer = tokenizer

@@ -1289,35 +1288,12 @@ class LoRAConfig:
            raise ValueError("LoRA is not supported with chunked prefill yet.")


# TODO: To be replaced by MultiModalConfig.
@dataclass
class VisionLanguageConfig:
class MultiModalConfig:
    """Configs the input data format and how models should run for
    vision language models."""
    # The input id corresponding to image token.
    image_token_id: int
    # Used for running `run_prefill_max_token`.
    # For models that support varying resolution, this corresponds to
    # worst case scenario (biggest supported resolution).
    image_input_shape: tuple
    image_feature_size: int

    def as_cli_args_dict(self) -> Dict[str, Any]:
        """Flatten vision language config to pure args.

        Compatible with what llm entrypoint expects.
        """
        result: Dict[str, Any] = {}
        for f in fields(self):
            value = getattr(self, f.name)
            if isinstance(value, enum.Enum):
                result[f.name] = value.name.lower()
            elif isinstance(value, tuple):
                result[f.name] = ",".join([str(item) for item in value])
            else:
                result[f.name] = value

        return result
    multimodal models."""

    # TODO: Add configs to init vision tower or not.
    pass


_STR_DTYPE_TO_TORCH_DTYPE = {

@@ -1541,7 +1517,7 @@ class EngineConfig:
    device_config: DeviceConfig
    load_config: LoadConfig
    lora_config: Optional[LoRAConfig]
    vision_language_config: Optional[VisionLanguageConfig]
    multimodal_config: Optional[MultiModalConfig]
    speculative_config: Optional[SpeculativeConfig]
    decoding_config: Optional[DecodingConfig]
    observability_config: Optional[ObservabilityConfig]
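A practical consequence of stripping these fields from the config: values that callers previously supplied via ``VisionLanguageConfig`` (such as ``image_token_id`` and ``image_input_shape``) are now derived from the HuggingFace model config, as the LLaVA changes further down in this diff do via ``config.image_token_index`` and ``config.vision_config.image_size``. A small illustrative sketch of where those values live (model name reused from the examples above):

.. code-block:: python

    from transformers import AutoConfig

    hf_config = AutoConfig.from_pretrained("llava-hf/llava-1.5-7b-hf")

    image_token_id = hf_config.image_token_index       # formerly VisionLanguageConfig.image_token_id
    image_size = hf_config.vision_config.image_size    # drives the expected pixel tensor shape
    expected_shape = (1, 3, image_size, image_size)    # formerly image_input_shape
    print(image_token_id, expected_shape)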
@ -6,11 +6,11 @@ from typing import List, Optional, Tuple, Union
|
||||
|
||||
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
|
||||
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
|
||||
ObservabilityConfig, ParallelConfig, SchedulerConfig,
|
||||
SpeculativeConfig, TokenizerPoolConfig,
|
||||
VisionLanguageConfig)
|
||||
MultiModalConfig, ObservabilityConfig, ParallelConfig,
|
||||
SchedulerConfig, SpeculativeConfig,
|
||||
TokenizerPoolConfig)
|
||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||
from vllm.utils import FlexibleArgumentParser, str_to_int_tuple
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
def nullable_str(val: str):
|
||||
@ -78,11 +78,6 @@ class EngineArgs:
|
||||
model_loader_extra_config: Optional[dict] = None
|
||||
preemption_mode: Optional[str] = None
|
||||
|
||||
# Related to Vision-language models such as llava
|
||||
image_token_id: Optional[int] = None
|
||||
image_input_shape: Optional[str] = None
|
||||
image_feature_size: Optional[int] = None
|
||||
|
||||
scheduler_delay_factor: float = 0.0
|
||||
enable_chunked_prefill: bool = False
|
||||
|
||||
@ -106,27 +101,6 @@ class EngineArgs:
|
||||
if self.tokenizer is None:
|
||||
self.tokenizer = self.model
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args_for_vlm(
|
||||
parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
parser.add_argument('--image-token-id',
|
||||
type=int,
|
||||
default=None,
|
||||
help=('Input id for image token.'))
|
||||
parser.add_argument(
|
||||
'--image-input-shape',
|
||||
type=nullable_str,
|
||||
default=None,
|
||||
help=('The biggest image input shape (worst for memory footprint) '
|
||||
'given an input type. Only used for vLLM\'s profile_run.'))
|
||||
parser.add_argument(
|
||||
'--image-feature-size',
|
||||
type=int,
|
||||
default=None,
|
||||
help=('The image feature size along the context dimension.'))
|
||||
|
||||
return parser
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
"""Shared CLI arguments for vLLM engine."""
|
||||
@ -484,9 +458,6 @@ class EngineArgs:
|
||||
],
|
||||
help='Device type for vLLM execution.')
|
||||
|
||||
# Related to Vision-language models such as llava
|
||||
parser = EngineArgs.add_cli_args_for_vlm(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'--scheduler-delay-factor',
|
||||
type=float,
|
||||
@ -648,19 +619,7 @@ class EngineArgs:
|
||||
raise ValueError(
|
||||
"BitsAndBytes load format and QLoRA adapter only support "
|
||||
f"'bitsandbytes' quantization, but got {self.quantization}")
|
||||
if self.image_token_id is not None:
|
||||
if (not self.image_input_shape or not self.image_feature_size):
|
||||
raise ValueError(
|
||||
'Specify `image_input_shape` and '
|
||||
'`image_feature_size` together with `image_token_id`.')
|
||||
|
||||
vision_language_config = VisionLanguageConfig(
|
||||
image_token_id=self.image_token_id,
|
||||
image_input_shape=str_to_int_tuple(self.image_input_shape),
|
||||
image_feature_size=self.image_feature_size,
|
||||
)
|
||||
else:
|
||||
vision_language_config = None
|
||||
multimodal_config = MultiModalConfig()
|
||||
|
||||
device_config = DeviceConfig(device=self.device)
|
||||
model_config = ModelConfig(
|
||||
@ -685,7 +644,7 @@ class EngineArgs:
|
||||
disable_sliding_window=self.disable_sliding_window,
|
||||
skip_tokenizer_init=self.skip_tokenizer_init,
|
||||
served_model_name=self.served_model_name,
|
||||
multimodal_config=vision_language_config)
|
||||
multimodal_config=multimodal_config)
|
||||
cache_config = CacheConfig(
|
||||
block_size=self.block_size,
|
||||
gpu_memory_utilization=self.gpu_memory_utilization,
|
||||
@ -787,7 +746,7 @@ class EngineArgs:
|
||||
scheduler_config=scheduler_config,
|
||||
device_config=device_config,
|
||||
lora_config=lora_config,
|
||||
vision_language_config=vision_language_config,
|
||||
multimodal_config=multimodal_config,
|
||||
speculative_config=speculative_config,
|
||||
load_config=load_config,
|
||||
decoding_config=decoding_config,
|
||||
@ -831,7 +790,3 @@ def _engine_args_parser():
|
||||
def _async_engine_args_parser():
|
||||
return AsyncEngineArgs.add_cli_args(FlexibleArgumentParser(),
|
||||
async_args_only=True)
|
||||
|
||||
|
||||
def _vlm_engine_args_parser():
|
||||
return EngineArgs.add_cli_args_for_vlm(FlexibleArgumentParser())
|
||||
|
||||
@ -7,9 +7,9 @@ from typing import Set, Type, TypeVar, Union
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
|
||||
LoRAConfig, ModelConfig, ObservabilityConfig,
|
||||
ParallelConfig, SchedulerConfig, SpeculativeConfig,
|
||||
VisionLanguageConfig)
|
||||
LoRAConfig, ModelConfig, MultiModalConfig,
|
||||
ObservabilityConfig, ParallelConfig, SchedulerConfig,
|
||||
SpeculativeConfig)
|
||||
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
|
||||
SchedulerOutputs)
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
@ -87,8 +87,8 @@ class LLMEngine:
|
||||
scheduler_config: The configuration related to the request scheduler.
|
||||
device_config: The configuration related to the device.
|
||||
lora_config (Optional): The configuration related to serving multi-LoRA.
|
||||
vision_language_config (Optional): The configuration related to vision
|
||||
language models.
|
||||
multimodal_config (Optional): The configuration related to multimodal
|
||||
models.
|
||||
speculative_config (Optional): The configuration related to speculative
|
||||
decoding.
|
||||
executor_class: The model executor class for managing distributed
|
||||
@ -157,7 +157,7 @@ class LLMEngine:
|
||||
device_config: DeviceConfig,
|
||||
load_config: LoadConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
vision_language_config: Optional[VisionLanguageConfig],
|
||||
multimodal_config: Optional[MultiModalConfig],
|
||||
speculative_config: Optional[SpeculativeConfig],
|
||||
decoding_config: Optional[DecodingConfig],
|
||||
observability_config: Optional[ObservabilityConfig],
|
||||
@ -215,7 +215,7 @@ class LLMEngine:
|
||||
self.model_config = model_config
|
||||
self.cache_config = cache_config
|
||||
self.lora_config = lora_config
|
||||
self.vision_language_config = vision_language_config
|
||||
self.multimodal_config = multimodal_config
|
||||
self.parallel_config = parallel_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.device_config = device_config
|
||||
@ -247,7 +247,7 @@ class LLMEngine:
|
||||
scheduler_config=scheduler_config,
|
||||
device_config=device_config,
|
||||
lora_config=lora_config,
|
||||
vision_language_config=vision_language_config,
|
||||
multimodal_config=multimodal_config,
|
||||
speculative_config=speculative_config,
|
||||
load_config=load_config,
|
||||
)
|
||||
|
||||
@@ -121,6 +121,11 @@ class LLM:
    ) -> None:
        if "disable_log_stats" not in kwargs:
            kwargs["disable_log_stats"] = True
        removed_vision_keys = ("image_token_id", "image_feature_size",
                               "image_input_shape", "image_input_type")
        if any(k in kwargs for k in removed_vision_keys):
            raise TypeError(
                "There is no need to pass vision-related arguments anymore.")
        engine_args = EngineArgs(
            model=model,
            tokenizer=tokenizer,
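To illustrate the guard added above: constructor calls written against the pre-``0.5.1`` vision arguments now fail fast with the ``TypeError`` shown, and no replacement arguments are needed. A rough sketch of the migration (model name reused from the docs above):

.. code-block:: python

    from vllm import LLM

    # Old (pre-0.5.1) style: now rejected by the check above.
    # llm = LLM(
    #     model="llava-hf/llava-1.5-7b-hf",
    #     image_token_id=32000,
    #     image_input_shape="1,3,336,336",
    #     image_feature_size=576,
    # )

    # New style: multimodal details are inferred internally.
    llm = LLM(model="llava-hf/llava-1.5-7b-hf")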
@@ -109,23 +109,12 @@ class OpenAIServingChat(OpenAIServing):
                          "paligemma"):
            # These models do not use image tokens in the prompt
            return None
        if model_type.startswith("llava"):
            return self.tokenizer.decode(
                self.model_config.hf_config.image_token_index)

        # The default behaviour assumes that the image token is
        # available to the tokenizer.
        # (Suitable for LLaVA, Idefics2, DeepSeek-VL)
        vlm_config = self.model_config.multimodal_config
        if vlm_config is None:
            raise ValueError(
                "'image_url' input is not supported as the loaded "
                "model is not multimodal.")

        image_token_id = vlm_config.image_token_id
        if vlm_config.image_token_id is None:
            raise ValueError(
                "'image_url' input is not supported as the loaded "
                "model does not specify an image token.")

        return self.tokenizer.decode(image_token_id)
        else:
            raise TypeError("Unknown model type: {model_type}")

        # TODO: Let user specify how to insert image tokens into prompt
        # (similar to chat template)
@ -46,7 +46,7 @@ class CPUExecutor(ExecutorBase):
|
||||
rank=0,
|
||||
distributed_init_method=distributed_init_method,
|
||||
lora_config=self.lora_config,
|
||||
vision_language_config=self.vision_language_config,
|
||||
multimodal_config=self.multimodal_config,
|
||||
kv_cache_dtype=self.cache_config.cache_dtype,
|
||||
is_driver_worker=True,
|
||||
)
|
||||
|
||||
@ -3,8 +3,8 @@ from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, ParallelConfig, SchedulerConfig,
|
||||
SpeculativeConfig, VisionLanguageConfig)
|
||||
ModelConfig, MultiModalConfig, ParallelConfig,
|
||||
SchedulerConfig, SpeculativeConfig)
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import ExecuteModelRequest, SamplerOutput
|
||||
|
||||
@ -26,7 +26,7 @@ class ExecutorBase(ABC):
|
||||
device_config: DeviceConfig,
|
||||
load_config: LoadConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
vision_language_config: Optional[VisionLanguageConfig],
|
||||
multimodal_config: Optional[MultiModalConfig],
|
||||
speculative_config: Optional[SpeculativeConfig],
|
||||
) -> None:
|
||||
self.model_config = model_config
|
||||
@ -36,7 +36,7 @@ class ExecutorBase(ABC):
|
||||
self.parallel_config = parallel_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.device_config = device_config
|
||||
self.vision_language_config = vision_language_config
|
||||
self.multimodal_config = multimodal_config
|
||||
self.speculative_config = speculative_config
|
||||
|
||||
self._init_executor()
|
||||
@ -120,7 +120,7 @@ class ExecutorAsyncBase(ExecutorBase):
|
||||
device_config: DeviceConfig,
|
||||
load_config: LoadConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
vision_language_config: Optional[VisionLanguageConfig],
|
||||
multimodal_config: Optional[MultiModalConfig],
|
||||
speculative_config: Optional[SpeculativeConfig],
|
||||
) -> None:
|
||||
# This locks each pipeline parallel stage so multiple virtual engines
|
||||
@ -132,8 +132,7 @@ class ExecutorAsyncBase(ExecutorBase):
|
||||
|
||||
super().__init__(model_config, cache_config, parallel_config,
|
||||
scheduler_config, device_config, load_config,
|
||||
lora_config, vision_language_config,
|
||||
speculative_config)
|
||||
lora_config, multimodal_config, speculative_config)
|
||||
|
||||
@abstractmethod
|
||||
async def execute_model_async(
|
||||
|
||||
@ -43,7 +43,7 @@ class GPUExecutor(ExecutorBase):
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
lora_config=self.lora_config,
|
||||
vision_language_config=self.vision_language_config,
|
||||
multimodal_config=self.multimodal_config,
|
||||
speculative_config=self.speculative_config,
|
||||
is_driver_worker=(not self.parallel_config)
|
||||
or (rank % self.parallel_config.tensor_parallel_size == 0),
|
||||
|
||||
@ -47,7 +47,7 @@ class OpenVINOExecutor(ExecutorBase):
|
||||
rank=0,
|
||||
distributed_init_method=distributed_init_method,
|
||||
lora_config=self.lora_config,
|
||||
vision_language_config=self.vision_language_config,
|
||||
multimodal_config=self.multimodal_config,
|
||||
kv_cache_dtype=self.cache_config.cache_dtype,
|
||||
is_driver_worker=True,
|
||||
)
|
||||
|
||||
@ -7,8 +7,8 @@ from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Set,
|
||||
Tuple, Union)
|
||||
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, ParallelConfig, SchedulerConfig,
|
||||
SpeculativeConfig, VisionLanguageConfig)
|
||||
ModelConfig, MultiModalConfig, ParallelConfig,
|
||||
SchedulerConfig, SpeculativeConfig)
|
||||
from vllm.executor.distributed_gpu_executor import ( # yapf: disable
|
||||
DistributedGPUExecutor, DistributedGPUExecutorAsync)
|
||||
from vllm.executor.ray_utils import RayWorkerWrapper, ray
|
||||
@ -43,7 +43,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
|
||||
device_config: DeviceConfig,
|
||||
load_config: LoadConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
vision_language_config: Optional[VisionLanguageConfig],
|
||||
multimodal_config: Optional[MultiModalConfig],
|
||||
speculative_config: Optional[SpeculativeConfig],
|
||||
) -> None:
|
||||
assert device_config.device_type == "xpu"
|
||||
@ -57,7 +57,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
|
||||
self.parallel_config = parallel_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.device_config = device_config
|
||||
self.vision_language_config = vision_language_config
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
placement_group = self.parallel_config.placement_group
|
||||
|
||||
@ -199,7 +199,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
lora_config=self.lora_config,
|
||||
vision_language_config=self.vision_language_config,
|
||||
multimodal_config=self.multimodal_config,
|
||||
is_driver_worker=rank == 0,
|
||||
))
|
||||
self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
|
||||
|
||||
@ -50,7 +50,7 @@ class TPUExecutor(ExecutorBase):
|
||||
local_rank=local_rank,
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
vision_language_config=self.vision_language_config,
|
||||
multimodal_config=self.multimodal_config,
|
||||
is_driver_worker=rank == 0,
|
||||
)
|
||||
|
||||
|
||||
@ -3,8 +3,8 @@ from typing import List, Optional
|
||||
import torch
|
||||
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, ParallelConfig, SchedulerConfig,
|
||||
SpeculativeConfig, VisionLanguageConfig)
|
||||
ModelConfig, MultiModalConfig, ParallelConfig,
|
||||
SchedulerConfig, SpeculativeConfig)
|
||||
from vllm.executor.executor_base import ExecutorAsyncBase
|
||||
from vllm.executor.gpu_executor import GPUExecutor
|
||||
from vllm.logger import init_logger
|
||||
@ -26,7 +26,7 @@ class XPUExecutor(GPUExecutor):
|
||||
device_config: DeviceConfig,
|
||||
load_config: LoadConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
vision_language_config: Optional[VisionLanguageConfig],
|
||||
multimodal_config: Optional[MultiModalConfig],
|
||||
speculative_config: Optional[SpeculativeConfig],
|
||||
) -> None:
|
||||
assert device_config.device_type == "xpu"
|
||||
@ -42,7 +42,7 @@ class XPUExecutor(GPUExecutor):
|
||||
self.parallel_config = parallel_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.device_config = device_config
|
||||
self.vision_language_config = vision_language_config
|
||||
self.multimodal_config = multimodal_config
|
||||
self.speculative_config = None
|
||||
|
||||
# Instantiate the worker and load the model to GPU.
|
||||
|
||||
@ -11,7 +11,7 @@ from vllm.logger import init_logger
|
||||
from .data import LLMInputs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig, VisionLanguageConfig
|
||||
from vllm.config import ModelConfig, MultiModalConfig
|
||||
from vllm.multimodal import MultiModalDataDict
|
||||
from vllm.sequence import SequenceData
|
||||
|
||||
@ -30,7 +30,7 @@ class InputContext:
|
||||
model_config: "ModelConfig"
|
||||
"""The configuration of the model."""
|
||||
|
||||
def get_multimodal_config(self) -> "VisionLanguageConfig":
|
||||
def get_multimodal_config(self) -> "MultiModalConfig":
|
||||
"""
|
||||
Get the multimodal configuration of the model.
|
||||
|
||||
|
||||
@ -3,8 +3,8 @@ from typing import Optional
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, ParallelConfig, SchedulerConfig,
|
||||
VisionLanguageConfig)
|
||||
ModelConfig, MultiModalConfig, ParallelConfig,
|
||||
SchedulerConfig)
|
||||
from vllm.model_executor.model_loader.loader import (BaseModelLoader,
|
||||
get_model_loader)
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
@ -15,13 +15,13 @@ def get_model(*, model_config: ModelConfig, load_config: LoadConfig,
|
||||
device_config: DeviceConfig, parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
vision_language_config: Optional[VisionLanguageConfig],
|
||||
multimodal_config: Optional[MultiModalConfig],
|
||||
cache_config: CacheConfig) -> nn.Module:
|
||||
loader = get_model_loader(load_config)
|
||||
return loader.load_model(model_config=model_config,
|
||||
device_config=device_config,
|
||||
lora_config=lora_config,
|
||||
vision_language_config=vision_language_config,
|
||||
multimodal_config=multimodal_config,
|
||||
parallel_config=parallel_config,
|
||||
scheduler_config=scheduler_config,
|
||||
cache_config=cache_config)
|
||||
|
||||
@ -16,8 +16,8 @@ from huggingface_hub import HfApi, hf_hub_download
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
|
||||
LoRAConfig, ModelConfig, ParallelConfig,
|
||||
SchedulerConfig, VisionLanguageConfig)
|
||||
LoRAConfig, ModelConfig, MultiModalConfig,
|
||||
ParallelConfig, SchedulerConfig)
|
||||
from vllm.envs import VLLM_USE_MODELSCOPE
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
@ -68,7 +68,7 @@ def _get_quantization_config(
|
||||
def _get_model_initialization_kwargs(
|
||||
model_class: Type[nn.Module],
|
||||
lora_config: Optional[LoRAConfig],
|
||||
vlm_config: Optional[VisionLanguageConfig],
|
||||
multimodal_config: Optional[MultiModalConfig],
|
||||
) -> Dict[str, Any]:
|
||||
"""Get extra kwargs for model initialization."""
|
||||
extra_kwargs: Dict[str, Any] = {}
|
||||
@ -84,18 +84,18 @@ def _get_model_initialization_kwargs(
|
||||
"please open an issue on github.")
|
||||
|
||||
if supports_vision(model_class):
|
||||
if vlm_config is None:
|
||||
if multimodal_config is None:
|
||||
raise ValueError("Provide vision related configurations "
|
||||
"through LLM entrypoint or engine arguments.")
|
||||
|
||||
extra_kwargs["vlm_config"] = vlm_config
|
||||
extra_kwargs["multimodal_config"] = multimodal_config
|
||||
|
||||
return extra_kwargs
|
||||
|
||||
|
||||
def _initialize_model(model_config: ModelConfig, load_config: LoadConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
vision_language_config: Optional[VisionLanguageConfig],
|
||||
multimodal_config: Optional[MultiModalConfig],
|
||||
cache_config: CacheConfig) -> nn.Module:
|
||||
"""Initialize a model with the given configurations."""
|
||||
model_class = get_model_architecture(model_config)[0]
|
||||
@ -105,7 +105,7 @@ def _initialize_model(model_config: ModelConfig, load_config: LoadConfig,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
**_get_model_initialization_kwargs(
|
||||
model_class, lora_config, vision_language_config))
|
||||
model_class, lora_config, multimodal_config))
|
||||
|
||||
|
||||
class BaseModelLoader(ABC):
|
||||
@ -118,7 +118,7 @@ class BaseModelLoader(ABC):
|
||||
def load_model(self, *, model_config: ModelConfig,
|
||||
device_config: DeviceConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
vision_language_config: Optional[VisionLanguageConfig],
|
||||
multimodal_config: Optional[MultiModalConfig],
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
cache_config: CacheConfig) -> nn.Module:
|
||||
@ -258,14 +258,14 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
def load_model(self, *, model_config: ModelConfig,
|
||||
device_config: DeviceConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
vision_language_config: Optional[VisionLanguageConfig],
|
||||
multimodal_config: Optional[MultiModalConfig],
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
cache_config: CacheConfig) -> nn.Module:
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with torch.device(device_config.device):
|
||||
model = _initialize_model(model_config, self.load_config,
|
||||
lora_config, vision_language_config,
|
||||
lora_config, multimodal_config,
|
||||
cache_config)
|
||||
model.load_weights(
|
||||
self._get_weights_iterator(model_config.model,
|
||||
@ -298,14 +298,14 @@ class DummyModelLoader(BaseModelLoader):
|
||||
def load_model(self, *, model_config: ModelConfig,
|
||||
device_config: DeviceConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
vision_language_config: Optional[VisionLanguageConfig],
|
||||
multimodal_config: Optional[MultiModalConfig],
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
cache_config: CacheConfig) -> nn.Module:
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with torch.device(device_config.device):
|
||||
model = _initialize_model(model_config, self.load_config,
|
||||
lora_config, vision_language_config,
|
||||
lora_config, multimodal_config,
|
||||
cache_config)
|
||||
# NOTE(woosuk): For accurate performance evaluation, we assign
|
||||
# random values to the weights.
|
||||
@ -339,7 +339,7 @@ class TensorizerLoader(BaseModelLoader):
|
||||
model_config: ModelConfig,
|
||||
device_config: DeviceConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
vision_language_config: Optional[VisionLanguageConfig],
|
||||
multimodal_config: Optional[MultiModalConfig],
|
||||
cache_config: CacheConfig,
|
||||
) -> nn.Module:
|
||||
"""Load a serialized model with tensorizer to the CPU.
|
||||
@ -352,7 +352,7 @@ class TensorizerLoader(BaseModelLoader):
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with torch.device(device_config.device):
|
||||
model = _initialize_model(model_config, self.load_config,
|
||||
lora_config, vision_language_config,
|
||||
lora_config, multimodal_config,
|
||||
cache_config)
|
||||
|
||||
model.load_weights(self._get_weights_iterator())
|
||||
@ -363,7 +363,7 @@ class TensorizerLoader(BaseModelLoader):
|
||||
model_config: ModelConfig,
|
||||
device_config: DeviceConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
vision_language_config: Optional[VisionLanguageConfig],
|
||||
multimodal_config: Optional[MultiModalConfig],
|
||||
cache_config: CacheConfig,
|
||||
) -> nn.Module:
|
||||
"""Load a serialized model with tensorizer.
|
||||
@ -377,7 +377,7 @@ class TensorizerLoader(BaseModelLoader):
|
||||
quant_config = _get_quantization_config(
|
||||
model_config, self.load_config)
|
||||
extra_kwargs = _get_model_initialization_kwargs(
|
||||
model_class, lora_config, vision_language_config)
|
||||
model_class, lora_config, multimodal_config)
|
||||
extra_kwargs["quant_config"] = quant_config
|
||||
extra_kwargs["cache_config"] = cache_config
|
||||
|
||||
@ -392,7 +392,7 @@ class TensorizerLoader(BaseModelLoader):
|
||||
def load_model(self, *, model_config: ModelConfig,
|
||||
device_config: DeviceConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
vision_language_config: Optional[VisionLanguageConfig],
|
||||
multimodal_config: Optional[MultiModalConfig],
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
cache_config: CacheConfig) -> nn.Module:
|
||||
@ -406,12 +406,10 @@ class TensorizerLoader(BaseModelLoader):
|
||||
|
||||
if is_vllm_tensorized(self.tensorizer_config):
|
||||
return self._load_model_serialized(model_config, device_config,
|
||||
lora_config,
|
||||
vision_language_config,
|
||||
lora_config, multimodal_config,
|
||||
cache_config)
|
||||
return self._load_model_serialized_cpu(model_config, device_config,
|
||||
lora_config,
|
||||
vision_language_config,
|
||||
lora_config, multimodal_config,
|
||||
cache_config)
|
||||
|
||||
@staticmethod
|
||||
@ -494,7 +492,7 @@ class ShardedStateLoader(BaseModelLoader):
|
||||
def load_model(self, *, model_config: ModelConfig,
|
||||
device_config: DeviceConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
vision_language_config: Optional[VisionLanguageConfig],
|
||||
multimodal_config: Optional[MultiModalConfig],
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
cache_config: CacheConfig) -> nn.Module:
|
||||
@ -508,7 +506,7 @@ class ShardedStateLoader(BaseModelLoader):
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with torch.device(device_config.device):
|
||||
model = _initialize_model(model_config, self.load_config,
|
||||
lora_config, vision_language_config,
|
||||
lora_config, multimodal_config,
|
||||
cache_config)
|
||||
rank = get_tensor_model_parallel_rank()
|
||||
pattern = os.path.join(
|
||||
@ -804,14 +802,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
def load_model(self, *, model_config: ModelConfig,
|
||||
device_config: DeviceConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
vision_language_config: Optional[VisionLanguageConfig],
|
||||
multimodal_config: Optional[MultiModalConfig],
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
cache_config: CacheConfig) -> nn.Module:
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with torch.device(device_config.device):
|
||||
model = _initialize_model(model_config, self.load_config,
|
||||
lora_config, vision_language_config,
|
||||
lora_config, multimodal_config,
|
||||
cache_config)
|
||||
|
||||
self._load_weights(model_config, model)
|
||||
|
||||
@ -3,7 +3,7 @@ from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type,
|
||||
|
||||
from typing_extensions import TypeGuard
|
||||
|
||||
from vllm.config import LoRAConfig, VisionLanguageConfig
|
||||
from vllm.config import LoRAConfig, MultiModalConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -22,7 +22,7 @@ class SupportsVision(Protocol):
|
||||
MRO of your model class.
|
||||
"""
|
||||
|
||||
def __init__(self, *, vlm_config: VisionLanguageConfig) -> None:
|
||||
def __init__(self, *, multimodal_config: MultiModalConfig) -> None:
|
||||
...
|
||||
|
||||
|
||||
@ -32,7 +32,7 @@ class SupportsVision(Protocol):
|
||||
class _SupportsVisionType(Protocol):
|
||||
supports_vision: Literal[True]
|
||||
|
||||
def __call__(self, *, vlm_config: VisionLanguageConfig) -> None:
|
||||
def __call__(self, *, multimodal_config: MultiModalConfig) -> None:
|
||||
...
|
||||
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ import torch.nn as nn
from transformers import CLIPVisionConfig, LlavaConfig

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VisionLanguageConfig
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@ -108,13 +108,13 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):

def __init__(self,
config: LlavaConfig,
vlm_config: VisionLanguageConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()

self.config = config
self.vlm_config = vlm_config
self.multimodal_config = multimodal_config

# TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = CLIPVisionModel(config.vision_config)
@ -138,14 +138,13 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
self.sampler = Sampler()

def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor:
if list(data.shape[1:]) != list(self.vlm_config.image_input_shape[1:]):
if list(data.shape)[1:] != [
3, self.config.vision_config.image_size,
self.config.vision_config.image_size
]:
raise ValueError(
f"The expected image tensor shape is batch dimension plus "
f"{self.vlm_config.image_input_shape[1:]}. "
f"You supplied {data.shape}. "
f"If you are using vLLM's entrypoint, make sure your "
f"supplied image input is consistent with "
f"image_input_shape in engine args.")
"The expected image tensor shape is batch dimension plus "
"channel, height and width.")

return data

@ -244,7 +243,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):

inputs_embeds = merge_vision_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.vlm_config.image_token_id)
self.config.image_token_index)

input_ids = None
else:
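As a side note (not part of this commit), the per-image shape accepted by the new ``_validate_image_data`` check is derived from the HuggingFace vision config rather than the removed ``image_input_shape`` engine argument; a hedged illustration using an example checkpoint:

.. code-block:: python

    from transformers import AutoConfig

    # Example model; any LLaVA checkpoint with a CLIP vision tower would do.
    hf_config = AutoConfig.from_pretrained("llava-hf/llava-1.5-7b-hf")
    size = hf_config.vision_config.image_size  # 336 for CLIP ViT-L/14-336
    expected = [3, size, size]                 # channel, height, width per image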
@ -1,4 +1,4 @@
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union

import torch
import torch.nn as nn
@ -9,7 +9,7 @@ from transformers.models.llava_next.modeling_llava_next import (
from typing_extensions import NotRequired

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VisionLanguageConfig
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@ -204,13 +204,13 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):

def __init__(self,
config: LlavaNextConfig,
vlm_config: VisionLanguageConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()

self.config = config
self.vlm_config = vlm_config
self.multimodal_config = multimodal_config

# TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = CLIPVisionModel(config=config.vision_config)
@ -244,6 +244,47 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):

return data

def _validate_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]:

def _validate_shape(data: torch.Tensor):

dim = data.dim()
height = width = self.config.vision_config.image_size
# All 4d image tensors have the same number of patches,
# so data is a 5d batch of these tensors
if dim == 5:
if list(data.shape)[2:] != [
3, self.config.vision_config.image_size,
self.config.vision_config.image_size
]:
raise ValueError(
"Expected pixel value tensor in shape of: (batch size, "
f"patch number, 3, {height}, {width}), got {data.shape}"
)

# 4d image tensors have different number of patches,
# so data is each individual tensor.
elif dim == 4:
if list(data.shape)[1:] != [
3, self.config.vision_config.image_size,
self.config.vision_config.image_size
]:
raise ValueError(
"Expected pixel value tensor in shape of: (patch "
f"number, 3, {height}, {width}), got {data.shape}")
else:
raise ValueError(
f"Invalid pixel value tensor of shape {data.shape}")

if isinstance(data, torch.Tensor):
_validate_shape(data)
else:
[_validate_shape(d) for d in data]

return data

def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaNextImagePixelInputs]:
pixel_values = kwargs.pop("pixel_values", None)
@ -262,7 +303,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):

return LlavaNextImagePixelInputs(
type="pixel_values",
data=pixel_values,
data=self._validate_pixel_values(pixel_values),
image_sizes=self._validate_image_sizes(image_sizes),
)

@ -454,7 +495,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):

inputs_embeds = merge_vision_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.vlm_config.image_token_id)
self.config.image_token_index)

input_ids = None
else:
@ -15,7 +15,7 @@
# limitations under the License.
import re
from functools import lru_cache
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union

import numpy as np
import torch
@ -24,7 +24,7 @@ from PIL import Image
from transformers import CLIPVisionConfig, PretrainedConfig

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VisionLanguageConfig
from vllm.config import CacheConfig, ModelConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@ -50,6 +50,9 @@ _KEYS_TO_MODIFY_MAPPING = {
"model.vision_embed_tokens": "vision_embed_tokens",
}

# Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 32044

CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
hidden_act="quick_gelu",
hidden_size=1024,
@ -95,13 +98,10 @@ class Phi3ImageEmbeddingBase(nn.Module):
class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
"""Phi3 Image embedding with HD transform."""

def __init__(self,
vision_language_config: VisionLanguageConfig,
config: PretrainedConfig,
wte=None) -> None:
def __init__(self, config: PretrainedConfig, wte=None) -> None:
super().__init__(wte)

self.image_token_id = vision_language_config.image_token_id
self.image_token_id = _IMAGE_TOKEN_ID
# n_embed or hidden_size
hidden_size = config.n_embd if hasattr(
config, 'n_embd') else config.hidden_size
@ -333,7 +333,7 @@ def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
seq_data = dummy_seq_data_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
seq_len,
image_token_id=32044,
image_token_id=_IMAGE_TOKEN_ID,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_image_for_clip(
@ -370,7 +370,6 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
return llm_inputs

model_config = ctx.model_config
multimodal_config = ctx.get_multimodal_config()
hf_config = ctx.get_hf_config(PretrainedConfig)

image_data = multi_modal_data["image"]
@ -407,7 +406,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
new_token_ids: List[int] = []
for i in range(len(prompt_token_ids) - len(image_1_token_ids) + 1):
if prompt_token_ids[i:i + len(image_1_token_ids)] == image_1_token_ids:
new_token_ids.append(multimodal_config.image_token_id)
new_token_ids.append(_IMAGE_TOKEN_ID)

# No need to further scan the list since we only replace once
new_token_ids.extend(prompt_token_ids[i + len(image_1_token_ids):])
@ -424,7 +423,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
model_config,
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
llm_inputs,
image_token_id=multimodal_config.image_token_id,
image_token_id=_IMAGE_TOKEN_ID,
image_feature_size_override=image_feature_size,
)

@ -436,25 +435,53 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):

def __init__(self,
config: PretrainedConfig,
vlm_config: VisionLanguageConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()

self.config = config
self.vlm_config = vlm_config
self.multimodal_config = multimodal_config

self.model = LlamaModel(config, cache_config, quant_config)

# TODO: Optionally initializes this for supporting embeddings.
self.vision_embed_tokens = Phi3HDImageEmbedding(
vlm_config, config, self.model.embed_tokens)
config, self.model.embed_tokens)
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()

def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
if list(data.shape[1:]) != [2]:
raise ValueError(
f"The expected image sizes shape is batch dimension plus "
f"{[2]}. You supplied {data.shape}.")

return data

def _validate_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]:

def _validate_shape(data: torch.Tensor):
if list(data.shape)[2:] != [
3, CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size,
CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
]:
raise ValueError(
"The expected pixel value tensor shape is batch dimension "
"plus patch number, channel, height and width.")

if isinstance(data, torch.Tensor):
_validate_shape(data)
else:
[_validate_shape(d) for d in data]

return data

def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Phi3VImagePixelInputs]:
pixel_values = kwargs.pop("pixel_values", None)
@ -471,9 +498,10 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
raise ValueError("Incorrect type of image sizes. "
f"Got type: {type(image_sizes)}")

return Phi3VImagePixelInputs(type="pixel_values",
data=pixel_values,
image_sizes=image_sizes)
return Phi3VImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
image_sizes=self._validate_image_sizes(image_sizes))

def forward(self,
input_ids: torch.Tensor,
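A hedged illustration (not part of this commit) of the pixel tensor layout that the new ``_validate_pixel_values`` check above accepts, assuming the 336-pixel input size implied by ``CLIP_VIT_LARGE_PATCH14_336_CONFIG``:

.. code-block:: python

    import torch

    # Hypothetical patch count produced by the HD transform for one image.
    num_patches = 5
    # Batched layout: (batch size, patch number, channel, height, width).
    pixel_values = torch.zeros(1, num_patches, 3, 336, 336)
    assert list(pixel_values.shape)[2:] == [3, 336, 336]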
@ -120,3 +120,10 @@ class MultiModalRegistry:
Create an input mapper (see :meth:`map_input`) for a specific model.
"""
return functools.partial(self.map_input, model_config)

def get_num_input_tokens(self):
"""
Get the number of input tokens for profiling purposes.
"""
# TODO: Provide this number on a per model basis.
return 3000
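For intuition, a hedged example with hypothetical scheduler settings showing how this hard-coded value feeds the profiling cap used by the model runners further below:

.. code-block:: python

    # Hypothetical scheduler settings, for illustration only.
    max_num_batched_tokens = 8192
    max_num_seqs = 256

    # Mirrors the capping logic in the GPU/XPU model runners below.
    num_input_tokens = 3000  # MULTIMODAL_REGISTRY.get_num_input_tokens()
    profiled_seqs = max(1, min(max_num_seqs,
                               max_num_batched_tokens // num_input_tokens))
    assert profiled_seqs == 2  # profiling runs with 2 image-bearing sequences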
@ -3,8 +3,8 @@ from typing import List, Optional
import torch

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig)
from vllm.logger import init_logger
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
@ -47,7 +47,7 @@ class TP1DraftModelRunner(ModelRunner):
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
vision_language_config: Optional[VisionLanguageConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
return_hidden_states: bool = False,
):
if return_hidden_states:
@ -65,7 +65,7 @@ class TP1DraftModelRunner(ModelRunner):
lora_config=lora_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker,
vision_language_config=vision_language_config,
multimodal_config=multimodal_config,
return_hidden_states=return_hidden_states,
)
@ -7,8 +7,8 @@ from torch import nn

from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
@ -79,7 +79,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
multimodal_config: Optional[MultiModalConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
*args,
@ -93,7 +93,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
self.device_config = device_config
self.cache_config = cache_config
self.lora_config = lora_config
self.vision_language_config = vision_language_config
self.multimodal_config = multimodal_config
self.load_config = load_config
self.is_driver_worker = is_driver_worker

@ -120,15 +120,14 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
self.model: nn.Module # Set after init_Model

def load_model(self) -> None:
self.model = get_model(
model_config=self.model_config,
load_config=self.load_config,
device_config=self.device_config,
vision_language_config=self.vision_language_config,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config)
self.model = get_model(model_config=self.model_config,
load_config=self.load_config,
device_config=self.device_config,
multimodal_config=self.multimodal_config,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config)

def _prepare_prompt(
self,
@ -6,8 +6,8 @@ import torch.distributed

from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
@ -131,7 +131,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
vision_language_config: Optional[VisionLanguageConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
) -> None:
@ -145,7 +145,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.vision_language_config = vision_language_config
self.multimodal_config = multimodal_config
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."
@ -162,7 +162,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
cache_config,
load_config=self.load_config,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
multimodal_config=self.multimodal_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker)
# Uninitialized cache engine. Will be initialized by
@ -4,8 +4,8 @@ from typing import Any, Dict, List, Optional, Tuple, Type
import torch

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.pooling_params import PoolingParams
@ -40,7 +40,7 @@ class EmbeddingModelRunner(
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
vision_language_config: Optional[VisionLanguageConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
):
super().__init__(model_config,
parallel_config,
@ -51,7 +51,7 @@ class EmbeddingModelRunner(
lora_config=lora_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker,
vision_language_config=vision_language_config)
multimodal_config=multimodal_config)

@torch.inference_mode()
def execute_model(
@ -24,8 +24,8 @@ except ImportError:

from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig)
from vllm.distributed import get_pp_group
from vllm.distributed.parallel_state import graph_capture
from vllm.inputs import INPUT_REGISTRY
@ -36,7 +36,8 @@ from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models.interfaces import supports_lora
from vllm.model_executor.models.interfaces import (supports_lora,
supports_vision)
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
MultiModalInputs)
from vllm.sampling_params import SamplingParams
@ -171,7 +172,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
vision_language_config: Optional[VisionLanguageConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
return_hidden_states: bool = False,
):
self.model_config = model_config
@ -182,7 +183,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.lora_config = lora_config
self.load_config = load_config
self.is_driver_worker = is_driver_worker
self.vision_language_config = vision_language_config
self.multimodal_config = multimodal_config
self.return_hidden_states = return_hidden_states

self.device = self.device_config.device
@ -244,7 +245,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
device_config=self.device_config,
load_config=self.load_config,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
multimodal_config=self.multimodal_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config,
@ -256,6 +257,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):

if self.lora_config:
assert supports_lora(self.model), "Model does not support LoRA"
assert not supports_vision(
self.model
), "To be tested: vision language model with LoRA settings."

self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs,
@ -804,12 +808,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
# the number of seqs (batch_size) is chosen to maximize the number
# of images processed.
model_config = self.model_config
vlm_config = self.vision_language_config

if vlm_config:
max_num_seqs = min(
max_num_seqs,
int(max_num_batched_tokens / vlm_config.image_feature_size))
if supports_vision(self.model):
max_num_seqs = max(
1,
min(
max_num_seqs,
int(max_num_batched_tokens /
MULTIMODAL_REGISTRY.get_num_input_tokens())))
batch_size = 0
for group_id in range(max_num_seqs):
seq_len = (max_num_batched_tokens // max_num_seqs +
@ -7,8 +7,8 @@ from torch import nn
from vllm.attention import get_attn_backend
from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader.openvino import get_model
@ -48,7 +48,7 @@ class OpenVINOModelRunner:
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
multimodal_config: Optional[MultiModalConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
*args,
@ -60,7 +60,7 @@ class OpenVINOModelRunner:
self.device_config = device_config
self.cache_config = cache_config
self.lora_config = lora_config
self.vision_language_config = vision_language_config
self.multimodal_config = multimodal_config
self.load_config = load_config
self.is_driver_worker = is_driver_worker
@ -7,8 +7,8 @@ import torch.distributed

from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig)
from vllm.distributed import (broadcast_tensor_dict,
ensure_model_parallel_initialized,
init_distributed_environment)
@ -148,7 +148,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
vision_language_config: Optional[VisionLanguageConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
kv_cache_dtype: Optional[ov.Type] = ov.Type.undefined,
is_driver_worker: bool = False,
) -> None:
@ -162,7 +162,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.vision_language_config = vision_language_config
self.multimodal_config = multimodal_config
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."
@ -180,7 +180,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
cache_config,
load_config=self.load_config,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
multimodal_config=self.multimodal_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker,
)
@ -8,7 +8,7 @@ import torch_xla.core.xla_model as xm

from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
MultiModalConfig, ParallelConfig, SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.sampling_metadata import SamplingMetadata
@ -39,7 +39,7 @@ class TPUModelRunner:
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
vision_language_config: Optional[VisionLanguageConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
is_driver_worker: bool = False,
):
self.model_config = model_config
@ -48,7 +48,7 @@ class TPUModelRunner:
self.device_config = device_config
self.cache_config = cache_config
self.load_config = load_config
self.vision_language_config = vision_language_config
self.multimodal_config = multimodal_config
self.is_driver_worker = is_driver_worker

self.block_size = self.cache_config.block_size
@ -82,7 +82,7 @@ class TPUModelRunner:
parallel_config=self.parallel_config,
cache_config=self.cache_config,
scheduler_config=self.scheduler_config,
vision_language_config=self.vision_language_config,
multimodal_config=self.multimodal_config,
lora_config=None,
)
xm.wait_device_ops()
@ -8,7 +8,7 @@ import torch_xla.runtime as xr

import vllm.envs as envs
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
MultiModalConfig, ParallelConfig, SchedulerConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
@ -31,7 +31,7 @@ class TPUWorker(LoraNotSupportedWorkerBase):
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
vision_language_config: Optional[VisionLanguageConfig],
multimodal_config: Optional[MultiModalConfig],
local_rank: int,
rank: int,
distributed_init_method: str,
@ -43,7 +43,7 @@ class TPUWorker(LoraNotSupportedWorkerBase):
self.device_config = device_config
self.cache_config = cache_config
self.load_config = load_config
self.vision_language_config = vision_language_config
self.multimodal_config = multimodal_config
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
@ -62,7 +62,7 @@ class TPUWorker(LoraNotSupportedWorkerBase):
device_config,
cache_config,
load_config,
vision_language_config,
multimodal_config,
is_driver_worker=is_driver_worker)

def init_device(self) -> None:
@ -7,8 +7,8 @@ import torch
import torch.distributed

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
SpeculativeConfig, VisionLanguageConfig)
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig, SpeculativeConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment,
set_custom_all_reduce)
@ -43,7 +43,7 @@ class Worker(LocalOrDistributedWorkerBase):
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
vision_language_config: Optional[VisionLanguageConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
speculative_config: Optional[SpeculativeConfig] = None,
is_driver_worker: bool = False,
model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None,
@ -66,10 +66,7 @@ class Worker(LocalOrDistributedWorkerBase):
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
self.vision_language_config = vision_language_config
if self.vision_language_config:
assert not self.lora_config, (
"To be tested: vision language model with LoRA settings.")
self.multimodal_config = multimodal_config

# Return hidden states from target model if the draft model is an
# mlp_speculator
@ -94,7 +91,7 @@ class Worker(LocalOrDistributedWorkerBase):
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=is_driver_worker,
vision_language_config=vision_language_config,
multimodal_config=multimodal_config,
**speculative_args,
)
# Uninitialized cache engine. Will be initialized by
@ -7,12 +7,13 @@ import torch.nn as nn

from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig)
from vllm.distributed import broadcast_tensor_dict
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models.interfaces import supports_vision
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
MultiModalInputs)
from vllm.sampling_params import SamplingParams
@ -85,7 +86,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
multimodal_config: Optional[MultiModalConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
*args,
@ -97,7 +98,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
self.lora_config = lora_config
self.load_config = load_config
self.cache_config = cache_config
self.vision_language_config = vision_language_config
self.multimodal_config = multimodal_config
self.is_driver_worker = is_driver_worker

self.sliding_window = model_config.get_sliding_window()
@ -134,7 +135,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
device_config=self.device_config,
load_config=self.load_config,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
multimodal_config=self.multimodal_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config,
@ -165,12 +166,16 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
# the number of seqs (batch_size) is chosen to maximize the number
# of images processed.
model_config = self.model_config
vlm_config = self.vision_language_config

if vlm_config:
max_num_seqs = min(
max_num_seqs,
int(max_num_batched_tokens / vlm_config.image_feature_size))
if supports_vision(self.model):
# TODO: properly inject these numbers from MultiModalRegistry.
# Right now, just use an overly conservative number.
max_num_seqs = max(
1,
min(
max_num_seqs,
int(max_num_batched_tokens /
MULTIMODAL_REGISTRY.get_num_input_tokens())))

for group_id in range(max_num_seqs):
seq_len = (max_num_batched_tokens // max_num_seqs +
@ -9,8 +9,8 @@ import torch
import torch.distributed

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
SpeculativeConfig, VisionLanguageConfig)
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig, SpeculativeConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
@ -45,7 +45,7 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
vision_language_config: Optional[VisionLanguageConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
speculative_config: Optional[SpeculativeConfig] = None,
is_driver_worker: bool = False,
) -> None:
@ -66,10 +66,7 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."

self.vision_language_config = vision_language_config
if self.vision_language_config:
assert not self.lora_config, (
"To be tested: vision language model with LoRA settings.")
self.multimodal_config = multimodal_config

self.model_runner = XPUModelRunner( # type: ignore
model_config,
@ -81,7 +78,7 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=is_driver_worker,
vision_language_config=vision_language_config,
multimodal_config=multimodal_config,
)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.