mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-18 07:35:02 +08:00)

[Feature] Add vision language model support. (#3042)
parent f408d05c52, commit 64172a976c
18  .buildkite/download-images.sh  Normal file
@@ -0,0 +1,18 @@
+#!/bin/bash
+
+set -ex
+set -o pipefail
+
+(which wget && which curl) || (apt-get update && apt-get install -y wget curl)
+
+# aws s3 sync s3://air-example-data-2/vllm_opensource_llava/ images/
+mkdir -p images
+cd images
+wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/stop_sign_pixel_values.pt
+wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/stop_sign_image_features.pt
+wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/cherry_blossom_pixel_values.pt
+wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/cherry_blossom_image_features.pt
+wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/stop_sign.jpg
+wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/cherry_blossom.jpg
+
+cd -
@@ -39,9 +39,15 @@ steps:
 
 - label: Models Test
   commands:
-  - pytest -v -s models --forked
+  - bash ../.buildkite/download-images.sh
+  - pytest -v -s models --ignore=models/test_llava.py --forked
   soft_fail: true
 
+- label: Llava Test
+  commands:
+  - bash ../.buildkite/download-images.sh
+  - pytest -v -s models/test_llava.py
+
 - label: Prefix Caching Test
   commands:
   - pytest -v -s prefix_caching
84  examples/llava_example.py  Normal file
@@ -0,0 +1,84 @@
+import argparse
+import os
+import subprocess
+
+import torch
+
+from vllm import LLM
+from vllm.sequence import MultiModalData
+
+# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.
+
+
+def run_llava_pixel_values():
+    llm = LLM(
+        model="llava-hf/llava-1.5-7b-hf",
+        image_input_type="pixel_values",
+        image_token_id=32000,
+        image_input_shape="1,3,336,336",
+        image_feature_size=576,
+    )
+
+    prompt = "<image>" * 576 + (
+        "\nUSER: What is the content of this image?\nASSISTANT:")
+
+    # This should be provided by another online or offline component.
+    images = torch.load("images/stop_sign_pixel_values.pt")
+
+    outputs = llm.generate(prompt,
+                           multi_modal_data=MultiModalData(
+                               type=MultiModalData.Type.IMAGE, data=images))
+    for o in outputs:
+        generated_text = o.outputs[0].text
+        print(generated_text)
+
+
+def run_llava_image_features():
+    llm = LLM(
+        model="llava-hf/llava-1.5-7b-hf",
+        image_input_type="image_features",
+        image_token_id=32000,
+        image_input_shape="1,576,1024",
+        image_feature_size=576,
+    )
+
+    prompt = "<image>" * 576 + (
+        "\nUSER: What is the content of this image?\nASSISTANT:")
+
+    # This should be provided by another online or offline component.
+    images = torch.load("images/stop_sign_image_features.pt")
+
+    outputs = llm.generate(prompt,
+                           multi_modal_data=MultiModalData(
+                               type=MultiModalData.Type.IMAGE, data=images))
+    for o in outputs:
+        generated_text = o.outputs[0].text
+        print(generated_text)
+
+
+def main(args):
+    if args.type == "pixel_values":
+        run_llava_pixel_values()
+    else:
+        run_llava_image_features()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Demo on Llava")
+    parser.add_argument("--type",
+                        type=str,
+                        choices=["pixel_values", "image_features"],
+                        default="pixel_values",
+                        help="image input type")
+    args = parser.parse_args()
+    # Download from s3
+    s3_bucket_path = "s3://air-example-data-2/vllm_opensource_llava/"
+    local_directory = "images"
+
+    # Make sure the local directory exists or create it
+    os.makedirs(local_directory, exist_ok=True)
+
+    # Use AWS CLI to sync the directory
+    subprocess.check_call(
+        ["aws", "s3", "sync", s3_bucket_path, local_directory])
+    main(args)
@@ -24,6 +24,10 @@ openai
 requests
 ray
 peft
+awscli

 # Benchmarking
 aiohttp
+
+# Multimodal
+pillow
@@ -3,16 +3,39 @@ from typing import List, Optional, Tuple

 import pytest
 import torch
-from transformers import AutoModelForCausalLM
+from PIL import Image
+from transformers import (AutoModelForCausalLM, AutoProcessor,
+                          LlavaForConditionalGeneration)

 from vllm import LLM, SamplingParams
-from vllm.config import TokenizerPoolConfig
+from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
+from vllm.sequence import MultiModalData
 from vllm.transformers_utils.tokenizer import get_tokenizer

 _TEST_DIR = os.path.dirname(__file__)
 _TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
 _LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]

+# Multi modal related
+_PIXEL_VALUES_FILES = [
+    os.path.join(_TEST_DIR, "images", filename) for filename in
+    ["stop_sign_pixel_values.pt", "cherry_blossom_pixel_values.pt"]
+]
+_IMAGE_FEATURES_FILES = [
+    os.path.join(_TEST_DIR, "images", filename) for filename in
+    ["stop_sign_image_features.pt", "cherry_blossom_image_features.pt"]
+]
+_IMAGE_FILES = [
+    os.path.join(_TEST_DIR, "images", filename)
+    for filename in ["stop_sign.jpg", "cherry_blossom.jpg"]
+]
+_IMAGE_PROMPTS = [
+    "<image>\nUSER: What's the content of the image?\nASSISTANT:",
+    "<image>\nUSER: What is the season?\nASSISTANT:"
+]
+assert len(_PIXEL_VALUES_FILES) == len(_IMAGE_FEATURES_FILES) == len(
+    _IMAGE_FILES) == len(_IMAGE_PROMPTS)
+

 def _read_prompts(filename: str) -> List[str]:
     with open(filename, "r") as f:
@@ -20,6 +43,39 @@ def _read_prompts(filename: str) -> List[str]:
     return prompts


+@pytest.fixture(scope="session")
+def hf_image_prompts() -> List[str]:
+    return _IMAGE_PROMPTS
+
+
+@pytest.fixture(scope="session")
+def hf_images() -> List[Image.Image]:
+    return [Image.open(filename) for filename in _IMAGE_FILES]
+
+
+@pytest.fixture()
+def vllm_images(request) -> "torch.Tensor":
+    vision_language_config = request.getfixturevalue("model_and_config")[1]
+    all_images = []
+    if vision_language_config.image_input_type == (
+            VisionLanguageConfig.ImageInputType.IMAGE_FEATURES):
+        filenames = _IMAGE_FEATURES_FILES
+    else:
+        filenames = _PIXEL_VALUES_FILES
+    for filename in filenames:
+        all_images.append(torch.load(filename))
+    return torch.concat(all_images, dim=0)
+
+
+@pytest.fixture()
+def vllm_image_prompts(request) -> List[str]:
+    vision_language_config = request.getfixturevalue("model_and_config")[1]
+    return [
+        "<image>" * (vision_language_config.image_feature_size - 1) + p
+        for p in _IMAGE_PROMPTS
+    ]
+
+
 @pytest.fixture
 def example_prompts() -> List[str]:
     prompts = []
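The `vllm_image_prompts` fixture above relies on each HF prompt already carrying a single `<image>` placeholder, so prepending `image_feature_size - 1` more copies yields exactly `image_feature_size` image tokens per prompt. A small illustration of that arithmetic (a sketch for clarity, not part of this commit):

# Illustration only: the prompt expansion done by vllm_image_prompts above.
feature_size = 576  # image_feature_size used for LLaVA-1.5 in this commit
hf_prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
vllm_prompt = "<image>" * (feature_size - 1) + hf_prompt
# The vLLM-side prompt now contains exactly `feature_size` image tokens.
assert vllm_prompt.count("<image>") == feature_size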
@@ -42,6 +98,10 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
     "float": torch.float,
 }

+_VISION_LANGUAGE_MODELS = {
+    "llava-hf/llava-1.5-7b-hf": LlavaForConditionalGeneration,
+}
+

 class HfRunner:

@@ -53,11 +113,24 @@ class HfRunner:
     ) -> None:
         assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
         torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            torch_dtype=torch_dtype,
-            trust_remote_code=True,
-        ).cuda()
+        self.model_name = model_name
+        if model_name not in _VISION_LANGUAGE_MODELS:
+            self.model = AutoModelForCausalLM.from_pretrained(
+                model_name,
+                torch_dtype=torch_dtype,
+                trust_remote_code=True,
+            ).cuda()
+            self.processor = None
+        else:
+            self.model = _VISION_LANGUAGE_MODELS[model_name].from_pretrained(
+                model_name,
+                torch_dtype=torch_dtype,
+                trust_remote_code=True,
+            ).cuda()
+            self.processor = AutoProcessor.from_pretrained(
+                model_name,
+                torch_dtype=torch_dtype,
+            )
         if tokenizer_name is None:
             tokenizer_name = model_name
         self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True)
@@ -65,13 +138,28 @@ class HfRunner:
     def generate(
         self,
         prompts: List[str],
+        images: Optional[List[Image.Image]] = None,
         **kwargs,
     ) -> List[Tuple[List[int], str]]:
         outputs: List[Tuple[List[int], str]] = []
-        for prompt in prompts:
-            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
+        if images:
+            assert len(prompts) == len(images)
+        for i, prompt in enumerate(prompts):
+            if self.model_name not in _VISION_LANGUAGE_MODELS:
+                input_ids = self.tokenizer(prompt,
+                                           return_tensors="pt").input_ids
+                inputs = {"input_ids": input_ids.cuda()}
+            else:
+                image = images[i] if images else None
+                inputs = self.processor(text=prompt,
+                                        images=image,
+                                        return_tensors="pt")
+                inputs = {
+                    key: value.cuda() if value is not None else None
+                    for key, value in inputs.items()
+                }
             output_ids = self.model.generate(
-                input_ids.cuda(),
+                **inputs,
                 use_cache=True,
                 **kwargs,
             )
@@ -88,10 +176,12 @@ class HfRunner:
         self,
         prompts: List[str],
         max_tokens: int,
+        images: Optional["torch.Tensor"] = None,
     ) -> List[Tuple[List[int], str]]:
         outputs = self.generate(prompts,
                                 do_sample=False,
-                                max_new_tokens=max_tokens)
+                                max_new_tokens=max_tokens,
+                                images=images)
         for i in range(len(outputs)):
             output_ids, output_str = outputs[i]
             outputs[i] = (output_ids[0], output_str[0])
@@ -183,9 +273,16 @@ class VllmRunner:
         self,
         prompts: List[str],
         sampling_params: SamplingParams,
+        images: Optional["torch.Tensor"] = None,
     ) -> List[Tuple[List[int], str]]:
-        req_outputs = self.model.generate(prompts,
-                                          sampling_params=sampling_params)
+        if images is not None:
+            assert len(prompts) == images.shape[0]
+        req_outputs = self.model.generate(
+            prompts,
+            sampling_params=sampling_params,
+            multi_modal_data=MultiModalData(type=MultiModalData.Type.IMAGE,
+                                            data=images)
+            if images is not None else None)
         outputs = []
         for req_output in req_outputs:
             prompt_str = req_output.prompt
@@ -222,9 +319,10 @@ class VllmRunner:
         self,
         prompts: List[str],
         max_tokens: int,
+        images: Optional[torch.Tensor] = None,
     ) -> List[Tuple[List[int], str]]:
         greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
-        outputs = self.generate(prompts, greedy_params)
+        outputs = self.generate(prompts, greedy_params, images=images)
         return [(output_ids[0], output_str[0])
                 for output_ids, output_str in outputs]

110  tests/models/test_llava.py  Normal file
@@ -0,0 +1,110 @@
+import gc
+from dataclasses import fields
+from enum import Enum
+from typing import Dict, List, Tuple
+
+import pytest
+import torch
+from transformers import AutoTokenizer
+
+from vllm.config import VisionLanguageConfig
+
+model_and_vl_config = [
+    ("llava-hf/llava-1.5-7b-hf",
+     VisionLanguageConfig(
+         image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
+         image_feature_size=576,
+         image_token_id=32000,
+         image_input_shape=(1, 3, 336, 336))),
+    ("llava-hf/llava-1.5-7b-hf",
+     VisionLanguageConfig(
+         image_input_type=VisionLanguageConfig.ImageInputType.IMAGE_FEATURES,
+         image_feature_size=576,
+         image_token_id=32000,
+         image_input_shape=(1, 576, 1024)))
+]
+
+
+def as_dict(vision_language_config: VisionLanguageConfig) -> Dict:
+    """Flatten vision language config to pure args.
+
+    Compatible with what llm entrypoint expects.
+    """
+    result = {}
+    for field in fields(vision_language_config):
+        value = getattr(vision_language_config, field.name)
+        if isinstance(value, Enum):
+            result[field.name] = value.name.lower()
+        elif isinstance(value, tuple):
+            result[field.name] = ",".join([str(item) for item in value])
+        else:
+            result[field.name] = value
+    return result
+
+
+def sanitize_vllm_output(vllm_output: Tuple[List[int], str],
+                         vision_language_config: VisionLanguageConfig,
+                         model_id: str):
+    """Sanitize vllm output to be comparable with hf output.
+    The function reduces `input_ids` from 1, 32000, 32000, ..., 32000,
+    x1, x2, x3 ... to 1, 32000, x1, x2, x3 ...
+    It also reduces `output_str` from "<image><image>bla" to "bla".
+    """
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    image_token_str = tokenizer.decode(vision_language_config.image_token_id)
+    image_token_str_len = len(image_token_str)
+    input_ids, output_str = vllm_output
+    sanitized_input_ids = input_ids[0:2] + input_ids[2 + vision_language_config
+                                                     .image_feature_size - 1:]
+    sanitzied_output_str = output_str[vision_language_config.
+                                      image_feature_size *
+                                      image_token_str_len:]
+    return sanitized_input_ids, sanitzied_output_str
+
+
+@pytest.mark.parametrize("worker_use_ray", [False])
+@pytest.mark.parametrize("model_and_config", model_and_vl_config)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [128])
+def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images,
+                vllm_image_prompts, vllm_images, model_and_config: tuple,
+                dtype: str, max_tokens: int, worker_use_ray: bool) -> None:
+    """Inference result should be the same between hf and vllm.
+
+    All the image fixtures for the test is under tests/images.
+    For huggingface runner, we provide the raw images as input.
+    For vllm runner, we provide image tensors and corresponding
+    vision language config as input.
+    Note, the text input is also adjusted to abide by vllm contract.
+    The text output is sanitized to be able to compare with hf.
+    """
+    model_id, vision_language_config = model_and_config
+    hf_model = hf_runner(model_id, dtype=dtype)
+    hf_outputs = hf_model.generate_greedy(hf_image_prompts,
+                                          max_tokens,
+                                          images=hf_images)
+    del hf_model
+
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    vllm_model = vllm_runner(model_id,
+                             dtype=dtype,
+                             worker_use_ray=worker_use_ray,
+                             **as_dict(vision_language_config))
+    vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
+                                              max_tokens,
+                                              images=vllm_images)
+    del vllm_model
+
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    for i in range(len(hf_image_prompts)):
+        hf_output_ids, hf_output_str = hf_outputs[i]
+        vllm_output_ids, vllm_output_str = sanitize_vllm_output(
+            vllm_outputs[i], vision_language_config, model_id)
+        assert hf_output_str == vllm_output_str, (
+            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
+        assert hf_output_ids == vllm_output_ids, (
+            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
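To make the comparison in `sanitize_vllm_output` concrete, here is a worked toy example of the slicing it performs (a sketch using a small made-up feature size, not part of the commit):

# Toy walk-through of sanitize_vllm_output's slicing, assuming an
# image_feature_size of 3 and that token id 32000 decodes to "<image>".
image_feature_size = 3
image_token_str = "<image>"

input_ids = [1, 32000, 32000, 32000, 7, 8]
sanitized_input_ids = input_ids[0:2] + input_ids[2 + image_feature_size - 1:]
assert sanitized_input_ids == [1, 32000, 7, 8]

output_str = "<image>" * image_feature_size + "bla"
sanitized_output_str = output_str[image_feature_size * len(image_token_str):]
assert sanitized_output_str == "bla"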
@@ -25,7 +25,7 @@ MODELS = [


 @pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [128])
 def test_models(
     hf_runner,
@@ -109,7 +109,7 @@ def create_worker(cls: type,
     )

    (model_config, cache_config, parallel_config, scheduler_config,
-     device_config, _) = engine_args.create_engine_configs()
+     device_config, _, _) = engine_args.create_engine_configs()

    distributed_init_method = get_distributed_init_method(
        get_ip(), get_open_port())
@@ -35,7 +35,7 @@ def test_prepare_prompt(batch_size):
                                           prompt_len - 1)
         selected_token_start_idx += prompt_len
    (input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _,
-     _) = (model_runner._prepare_prompt(seq_group_metadata_list))
+     _, _) = (model_runner._prepare_prompt(seq_group_metadata_list))
    assert return_prompt_lens == prompt_lens

    # Verify input metadata is correct for prompts.
@@ -11,7 +11,7 @@ def test_swap() -> None:
                              dtype="half",
                              load_format="dummy")
    (model_config, cache_config, parallel_config, scheduler_config,
-     device_config, _) = engine_args.create_engine_configs()
+     device_config, _, _) = engine_args.create_engine_configs()
    cache_config.num_gpu_blocks = 100
    cache_config.num_cpu_blocks = 100

@@ -1,3 +1,4 @@
+import enum
 import json
 import os
 from dataclasses import dataclass
@@ -8,7 +9,7 @@ from packaging.version import Version
 from transformers import PretrainedConfig

 from vllm.logger import init_logger
-from vllm.transformers_utils.config import get_config
+from vllm.transformers_utils.config import get_config, get_hf_text_config
 from vllm.utils import get_cpu_memory, get_nvcc_cuda_version, is_hip, is_neuron

 if TYPE_CHECKING:
@@ -118,8 +119,9 @@ class ModelConfig:

         self.hf_config = get_config(self.model, trust_remote_code, revision,
                                     code_revision)
-        self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
-        self.max_model_len = _get_and_verify_max_len(self.hf_config,
+        self.hf_text_config = get_hf_text_config(self.hf_config)
+        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
+        self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
                                                      max_model_len)
         self._verify_load_format()
         self._verify_tokenizer_mode()
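The helper `get_hf_text_config` itself is not part of this diff. Presumably it returns the nested language-model config (e.g. a LLaVA config's `text_config`) when one exists and falls back to the top-level config otherwise; a hypothetical sketch for orientation only:

# Hypothetical sketch only -- the real helper lives in
# vllm/transformers_utils/config.py and is not shown in this hunk.
def get_hf_text_config(config):
    if hasattr(config, "text_config"):
        # Multi-modal configs such as LLaVA's nest the language model's
        # hyperparameters (hidden_size, num_attention_heads, ...) here.
        return config.text_config
    # Plain text-only models keep them on the top-level config.
    return config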
@@ -218,7 +220,7 @@ class ModelConfig:
         self,
         parallel_config: "ParallelConfig",
     ) -> None:
-        total_num_attention_heads = self.hf_config.num_attention_heads
+        total_num_attention_heads = self.hf_text_config.num_attention_heads
         tensor_parallel_size = parallel_config.tensor_parallel_size
         if total_num_attention_heads % tensor_parallel_size != 0:
             raise ValueError(
@@ -226,7 +228,7 @@ class ModelConfig:
                 " must be divisible by tensor parallel size "
                 f"({tensor_parallel_size}).")

-        total_num_hidden_layers = self.hf_config.num_hidden_layers
+        total_num_hidden_layers = self.hf_text_config.num_hidden_layers
         pipeline_parallel_size = parallel_config.pipeline_parallel_size
         if total_num_hidden_layers % pipeline_parallel_size != 0:
             raise ValueError(
@@ -241,22 +243,23 @@ class ModelConfig:
         # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
         # addition to sliding window size. We check if that field is present
         # and if it's False, return None.
-        if (hasattr(self.hf_config, "use_sliding_window")
-                and not self.hf_config.use_sliding_window):
+        if (hasattr(self.hf_text_config, "use_sliding_window")
+                and not self.hf_text_config.use_sliding_window):
             return None
-        return getattr(self.hf_config, "sliding_window", None)
+        return getattr(self.hf_text_config, "sliding_window", None)

     def get_vocab_size(self) -> int:
-        return self.hf_config.vocab_size
+        return self.hf_text_config.vocab_size

     def get_hidden_size(self) -> int:
-        return self.hf_config.hidden_size
+        return self.hf_text_config.hidden_size

     def get_head_size(self) -> int:
-        if hasattr(self.hf_config, "head_dim"):
-            return self.hf_config.head_dim
+        if hasattr(self.hf_text_config, "head_dim"):
+            return self.hf_text_config.head_dim
         # FIXME(woosuk): This may not be true for all models.
-        return self.hf_config.hidden_size // self.hf_config.num_attention_heads
+        return (self.hf_text_config.hidden_size //
+                self.hf_text_config.num_attention_heads)

     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""
@@ -268,7 +271,7 @@ class ModelConfig:
         new_decoder_arch_falcon = (
             self.hf_config.model_type in falcon_model_types
             and getattr(self.hf_config, "new_decoder_architecture", False))
-        if not new_decoder_arch_falcon and getattr(self.hf_config,
+        if not new_decoder_arch_falcon and getattr(self.hf_text_config,
                                                    "multi_query", False):
             # Multi-query attention, only one KV head.
             # Currently, tensor parallelism is not supported in this case.
@@ -284,13 +287,13 @@ class ModelConfig:
             "multi_query_group_num",
         ]
         for attr in attributes:
-            num_kv_heads = getattr(self.hf_config, attr, None)
+            num_kv_heads = getattr(self.hf_text_config, attr, None)
             if num_kv_heads is not None:
                 return num_kv_heads

         # For non-grouped-query attention models, the number of KV heads is
         # equal to the number of attention heads.
-        return self.hf_config.num_attention_heads
+        return self.hf_text_config.num_attention_heads

     def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
         """Returns the number of KV heads per GPU."""
@@ -303,7 +306,7 @@ class ModelConfig:
             total_num_kv_heads // parallel_config.tensor_parallel_size)

     def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
-        total_num_hidden_layers = self.hf_config.num_hidden_layers
+        total_num_hidden_layers = self.hf_text_config.num_hidden_layers
         return total_num_hidden_layers // parallel_config.pipeline_parallel_size


@@ -627,6 +630,48 @@ class LoRAConfig:
                 "LoRA is enabled.")


+@dataclass
+class VisionLanguageConfig:
+    """Configs the input data format and how models should run for
+    vision language models."""
+
+    class ImageInputType(enum.Enum):
+        """Image input type into the vision language model.
+
+        An image roughly goes through the following transformation:
+        Raw image --> pixel values --> image features --> image embeddings.
+
+        The difference between different image input types is where the
+        image encoder (pixel values --> image features) is run.
+        Different image input types also correspond to different tensor shapes.
+
+        For example, for Llava, PIXEL_VALUES: (1, 3, 336, 336).
+        IMAGE_FEATURES: (1, 576, 1024).
+        """
+        PIXEL_VALUES = enum.auto()
+        IMAGE_FEATURES = enum.auto()
+
+    image_input_type: ImageInputType
+    # The input id corresponding to image token.
+    image_token_id: int
+    # Used for running `run_prefill_max_token`.
+    # For models that support varying resolution, this corresponds to
+    # worst case scenario (biggest supported resolution).
+    image_input_shape: tuple
+    image_feature_size: int
+
+    @classmethod
+    def get_image_input_enum_type(
+            cls, value: str) -> "VisionLanguageConfig.ImageInputType":
+        """Get the image input type from a string."""
+        try:
+            return cls.ImageInputType[value.upper()]
+        except KeyError as e:
+            raise ValueError(f"{value} is not a valid choice. "
+                             f"Expecting to choose from "
+                             f"{[x.name for x in cls.ImageInputType]}.") from e
+
+
 _STR_DTYPE_TO_TORCH_DTYPE = {
     "half": torch.float16,
     "float16": torch.float16,
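For orientation, a minimal sketch of constructing this config directly, mirroring the LLaVA-1.5 pixel-value settings used elsewhere in this commit (not itself part of the diff):

from vllm.config import VisionLanguageConfig

# Values match the llava-1.5-7b example/test settings in this commit.
vlm_config = VisionLanguageConfig(
    image_input_type=VisionLanguageConfig.get_image_input_enum_type(
        "pixel_values"),
    image_token_id=32000,
    image_input_shape=(1, 3, 336, 336),
    image_feature_size=576,
)
assert (vlm_config.image_input_type ==
        VisionLanguageConfig.ImageInputType.PIXEL_VALUES)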
@@ -388,6 +388,12 @@ class Scheduler:
                 computed_block_nums=self.block_manager.
                 get_common_computed_block_ids(seq_group),
                 state=seq_group.state,
+                # `multi_modal_data` will only be present for the 1st comm
+                # between engine and worker.
+                # the subsequent comms can still use delta, but
+                # `multi_modal_data` will be None.
+                multi_modal_data=seq_group.multi_modal_data
+                if scheduler_outputs.prompt_run else None,
             )
             seq_group_metadata_list.append(seq_group_metadata)
         return seq_group_metadata_list, scheduler_outputs
@@ -4,7 +4,9 @@ from dataclasses import dataclass
 from typing import Optional, Tuple

 from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig, TokenizerPoolConfig)
+                         ParallelConfig, SchedulerConfig, TokenizerPoolConfig,
+                         VisionLanguageConfig)
+from vllm.utils import str_to_int_tuple


 @dataclass
@@ -50,6 +52,11 @@ class EngineArgs:
     max_cpu_loras: Optional[int] = None
     device: str = 'auto'
     ray_workers_use_nsight: bool = False
+    # Related to Vision-language models such as llava
+    image_input_type: Optional[str] = None
+    image_token_id: Optional[int] = None
+    image_input_shape: Optional[str] = None
+    image_feature_size: Optional[int] = None
     scheduler_delay_factor: float = 0.0

     def __post_init__(self):
@@ -305,6 +312,31 @@ class EngineArgs:
                             default=EngineArgs.device,
                             choices=["auto", "cuda", "neuron"],
                             help='Device type for vLLM execution.')
+        # Related to Vision-language models such as llava
+        parser.add_argument(
+            '--image-input-type',
+            type=str,
+            default=None,
+            choices=[
+                t.name.lower() for t in VisionLanguageConfig.ImageInputType
+            ],
+            help=('The image input type passed into vLLM. '
+                  'Should be one of "pixel_values" or "image_features".'))
+        parser.add_argument('--image-token-id',
+                            type=int,
+                            default=None,
+                            help=('Input id for image token.'))
+        parser.add_argument(
+            '--image-input-shape',
+            type=str,
+            default=None,
+            help=('The biggest image input shape (worst for memory footprint) '
+                  'given an input type. Only used for vLLM\'s profile_run.'))
+        parser.add_argument(
+            '--image-feature-size',
+            type=int,
+            default=None,
+            help=('The image feature size along the context dimension.'))
         parser.add_argument(
             '--scheduler-delay-factor',
             type=float,
@@ -324,7 +356,8 @@ class EngineArgs:
     def create_engine_configs(
         self,
     ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig,
-               DeviceConfig, Optional[LoRAConfig]]:
+               DeviceConfig, Optional[LoRAConfig],
+               Optional[VisionLanguageConfig]]:
         device_config = DeviceConfig(self.device)
         model_config = ModelConfig(
             self.model, self.tokenizer, self.tokenizer_mode,
@@ -358,8 +391,25 @@ class EngineArgs:
             lora_dtype=self.lora_dtype,
             max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
             and self.max_cpu_loras > 0 else None) if self.enable_lora else None
+
+        if self.image_input_type:
+            if (not self.image_token_id or not self.image_input_shape
+                    or not self.image_feature_size):
+                raise ValueError(
+                    'Specify `image_token_id`, `image_input_shape` and '
+                    '`image_feature_size` together with `image_input_type`.')
+            vision_language_config = VisionLanguageConfig(
+                image_input_type=VisionLanguageConfig.
+                get_image_input_enum_type(self.image_input_type),
+                image_token_id=self.image_token_id,
+                image_input_shape=str_to_int_tuple(self.image_input_shape),
+                image_feature_size=self.image_feature_size,
+            )
+        else:
+            vision_language_config = None
+
         return (model_config, cache_config, parallel_config, scheduler_config,
-                device_config, lora_config)
+                device_config, lora_config, vision_language_config)


 @dataclass
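A sketch of exercising the new engine-args plumbing end to end (argument values reused from the LLaVA example elsewhere in this commit; not part of the diff, and it downloads the model config when run):

from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(model="llava-hf/llava-1.5-7b-hf",
                         image_input_type="pixel_values",
                         image_token_id=32000,
                         image_input_shape="1,3,336,336",
                         image_feature_size=576)
# create_engine_configs now returns a 7-tuple ending with the optional
# VisionLanguageConfig (None when --image-input-type is not given).
(model_config, cache_config, parallel_config, scheduler_config,
 device_config, lora_config, vision_language_config) = (
     engine_args.create_engine_configs())
assert vision_language_config is not None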
@@ -15,6 +15,7 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
+from vllm.sequence import MultiModalData

 logger = init_logger(__name__)
 ENGINE_ITERATION_TIMEOUT_S = int(
@@ -240,6 +241,7 @@ class _AsyncLLMEngine(LLMEngine):
         prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
     ) -> None:
         if lora_request is not None and not self.lora_config:
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "
@@ -252,14 +254,13 @@ class _AsyncLLMEngine(LLMEngine):
             prompt_token_ids=prompt_token_ids,
             lora_request=lora_request)

-        return self.add_request(
-            request_id,
-            prompt=prompt,
-            prompt_token_ids=prompt_token_ids,
-            sampling_params=sampling_params,
-            arrival_time=arrival_time,
-            lora_request=lora_request,
-        )
+        return self.add_request(request_id,
+                                prompt=prompt,
+                                prompt_token_ids=prompt_token_ids,
+                                sampling_params=sampling_params,
+                                arrival_time=arrival_time,
+                                lora_request=lora_request,
+                                multi_modal_data=multi_modal_data)

     async def check_health_async(self) -> None:
         self.model_executor.check_health()
@@ -486,6 +487,7 @@ class AsyncLLMEngine:
         prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
     ) -> AsyncStream:
         if self.log_requests:
             shortened_prompt = prompt
@@ -534,7 +536,9 @@ class AsyncLLMEngine:
             sampling_params=sampling_params,
             prompt_token_ids=prompt_token_ids,
             arrival_time=arrival_time,
-            lora_request=lora_request)
+            lora_request=lora_request,
+            multi_modal_data=multi_modal_data,
+        )

         return stream

@@ -545,6 +549,7 @@ class AsyncLLMEngine:
         request_id: str,
         prompt_token_ids: Optional[List[int]] = None,
         lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None
     ) -> AsyncIterator[RequestOutput]:
         """Generate outputs for a request.

@@ -560,6 +565,7 @@ class AsyncLLMEngine:
             prompt_token_ids: The token IDs of the prompt. If None, we
                 use the tokenizer to convert the prompts to token IDs.
             lora_request: LoRA request to use for generation, if any.
+            multi_modal_data: Multi modal data per request.

         Yields:
             The output `RequestOutput` objects from the LLMEngine for the
@@ -619,6 +625,7 @@ class AsyncLLMEngine:
                 prompt_token_ids=prompt_token_ids,
                 arrival_time=arrival_time,
                 lora_request=lora_request,
+                multi_modal_data=multi_modal_data,
             )

             async for request_output in stream:
@@ -5,7 +5,7 @@ from transformers import PreTrainedTokenizer

 import vllm
 from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig)
+                         ParallelConfig, SchedulerConfig, VisionLanguageConfig)
 from vllm.core.scheduler import Scheduler, SchedulerOutputs
 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.metrics import StatLogger, Stats
@@ -15,8 +15,9 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
-                           SequenceGroupOutput, SequenceOutput, SequenceStatus)
+from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
+                           SequenceGroup, SequenceGroupOutput, SequenceOutput,
+                           SequenceStatus)
 from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
                                                      get_tokenizer_group)
@@ -62,6 +63,7 @@ class LLMEngine:
         scheduler_config: SchedulerConfig,
         device_config: DeviceConfig,
         lora_config: Optional[LoRAConfig],
+        vision_language_config: Optional["VisionLanguageConfig"],
         executor_class: Type[ExecutorBase],
         log_stats: bool,
     ) -> None:
@@ -90,6 +92,7 @@ class LLMEngine:
         self.model_config = model_config
         self.cache_config = cache_config
         self.lora_config = lora_config
+        self.vision_language_config = vision_language_config
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
         self.device_config = device_config
@@ -102,7 +105,8 @@ class LLMEngine:

         self.model_executor = executor_class(model_config, cache_config,
                                              parallel_config, scheduler_config,
-                                             device_config, lora_config)
+                                             device_config, lora_config,
+                                             vision_language_config)

         # Ping the tokenizer to ensure liveness if it runs in a
         # different process.
@@ -170,7 +174,6 @@ class LLMEngine:
             trust_remote_code=self.model_config.trust_remote_code,
             revision=self.model_config.tokenizer_revision)
         init_kwargs.update(tokenizer_init_kwargs)
-
         self.tokenizer: BaseTokenizerGroup = get_tokenizer_group(
             self.parallel_config.tokenizer_pool_config, **init_kwargs)

@@ -212,6 +215,7 @@ class LLMEngine:
         prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
     ) -> None:
         """Add a request to the engine's request pool.

@@ -228,6 +232,7 @@ class LLMEngine:
                 use the tokenizer to convert the prompts to token IDs.
             arrival_time: The arrival time of the request. If None, we use
                 the current monotonic time.
+            multi_modal_data: Multi modal data per request.

         Details:
             - Set arrival_time to the current time if it is None.
@@ -288,7 +293,7 @@ class LLMEngine:

         # Create the sequence group.
         seq_group = SequenceGroup(request_id, [seq], sampling_params,
-                                  arrival_time, lora_request)
+                                  arrival_time, lora_request, multi_modal_data)

         # Add the sequence group to the scheduler.
         self.scheduler.add_seq_group(seq_group)
@@ -1,5 +1,6 @@
 from typing import List, Optional, Union

+import torch
 from tqdm import tqdm
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

@@ -8,6 +9,7 @@ from vllm.engine.llm_engine import LLMEngine
 from vllm.lora.request import LoRARequest
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
+from vllm.sequence import MultiModalData
 from vllm.utils import Counter


@@ -126,6 +128,7 @@ class LLM:
         prompt_token_ids: Optional[List[List[int]]] = None,
         use_tqdm: bool = True,
         lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
     ) -> List[RequestOutput]:
         """Generates the completions for the input prompts.

@@ -141,6 +144,7 @@ class LLM:
                 use the tokenizer to convert the prompts to token IDs.
             use_tqdm: Whether to use tqdm to display the progress bar.
             lora_request: LoRA request to use for generation, if any.
+            multi_modal_data: Multi modal data.

         Returns:
             A list of `RequestOutput` objects containing the generated
@@ -160,6 +164,9 @@ class LLM:
             # Use default sampling params.
             sampling_params = SamplingParams()

+        if multi_modal_data:
+            multi_modal_data.data = multi_modal_data.data.to(torch.float16)
+
         # Add requests to the engine.
         num_requests = len(prompts) if prompts is not None else len(
             prompt_token_ids)
@@ -167,10 +174,17 @@ class LLM:
             prompt = prompts[i] if prompts is not None else None
             token_ids = None if prompt_token_ids is None else prompt_token_ids[
                 i]
-            self._add_request(prompt,
-                              sampling_params,
-                              token_ids,
-                              lora_request=lora_request)
+            self._add_request(
+                prompt,
+                sampling_params,
+                token_ids,
+                lora_request=lora_request,
+                # Get ith image while maintaining the batch dim.
+                multi_modal_data=MultiModalData(
+                    type=multi_modal_data.type,
+                    data=multi_modal_data.data[i].unsqueeze(0))
+                if multi_modal_data else None,
+            )
         return self._run_engine(use_tqdm)

     def _add_request(
@@ -179,13 +193,15 @@ class LLM:
         sampling_params: SamplingParams,
         prompt_token_ids: Optional[List[int]],
         lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
     ) -> None:
         request_id = str(next(self.request_counter))
         self.llm_engine.add_request(request_id,
                                     prompt,
                                     sampling_params,
                                     prompt_token_ids,
-                                    lora_request=lora_request)
+                                    lora_request=lora_request,
+                                    multi_modal_data=multi_modal_data)

     def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
         # Initialize tqdm.
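As the slicing above shows, `LLM.generate` splits `multi_modal_data.data` along the batch dimension, one row per prompt. A batched usage sketch (assumes the tensors downloaded by `.buildkite/download-images.sh` and an `llm` built as in `examples/llava_example.py`; not part of the diff):

import torch
from vllm.sequence import MultiModalData

# `llm` is assumed to be an LLM configured for pixel_values LLaVA input,
# exactly as in examples/llava_example.py above.
prompts = [
    "<image>" * 576 + "\nUSER: What is the content of this image?\nASSISTANT:",
    "<image>" * 576 + "\nUSER: What is the season?\nASSISTANT:",
]
# Batch dim 0 must match len(prompts); generate() takes row i for prompt i.
images = torch.cat([
    torch.load("images/stop_sign_pixel_values.pt"),
    torch.load("images/cherry_blossom_pixel_values.pt"),
], dim=0)
outputs = llm.generate(prompts,
                       multi_modal_data=MultiModalData(
                           type=MultiModalData.Type.IMAGE, data=images))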
@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
 from typing import Dict, List, Optional

 from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig)
+                         ParallelConfig, SchedulerConfig, VisionLanguageConfig)
 from vllm.lora.request import LoRARequest
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata

@@ -24,6 +24,7 @@ class ExecutorBase(ABC):
         scheduler_config: SchedulerConfig,
         device_config: DeviceConfig,
         lora_config: Optional[LoRAConfig],
+        vision_language_config: Optional[VisionLanguageConfig],
     ) -> None:
         raise NotImplementedError

@@ -1,7 +1,7 @@
 from typing import Dict, List, Optional

 from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig)
+                         ParallelConfig, SchedulerConfig, VisionLanguageConfig)
 from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from vllm.executor.utils import check_block_size_valid
 from vllm.logger import init_logger
@@ -23,6 +23,7 @@ class GPUExecutor(ExecutorBase):
         scheduler_config: SchedulerConfig,
         device_config: DeviceConfig,
         lora_config: Optional[LoRAConfig],
+        vision_language_config: Optional[VisionLanguageConfig],
     ) -> None:
         self.model_config = model_config
         self.cache_config = cache_config
@@ -30,6 +31,7 @@ class GPUExecutor(ExecutorBase):
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
         self.device_config = device_config
+        self.vision_language_config = vision_language_config

         # Instantiate the worker and load the model to GPU.
         self._init_worker()
@@ -56,6 +58,7 @@ class GPUExecutor(ExecutorBase):
             rank=0,
             distributed_init_method=distributed_init_method,
             lora_config=self.lora_config,
+            vision_language_config=self.vision_language_config,
             kv_cache_dtype=self.cache_config.cache_dtype,
             is_driver_worker=True,
         )
@@ -6,7 +6,7 @@ from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional

 from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig)
+                         ParallelConfig, SchedulerConfig, VisionLanguageConfig)
 from vllm.engine.ray_utils import RayWorkerVllm, ray
 from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from vllm.executor.utils import check_block_size_valid
@@ -40,6 +40,7 @@ class RayGPUExecutor(ExecutorBase):
         scheduler_config: SchedulerConfig,
         device_config: DeviceConfig,
         lora_config: Optional[LoRAConfig],
+        vision_language_config: Optional[VisionLanguageConfig],
     ) -> None:
         self.model_config = model_config
         self.cache_config = cache_config
@ -47,6 +48,7 @@ class RayGPUExecutor(ExecutorBase):
|
|||||||
self.parallel_config = parallel_config
|
self.parallel_config = parallel_config
|
||||||
self.scheduler_config = scheduler_config
|
self.scheduler_config = scheduler_config
|
||||||
self.device_config = device_config
|
self.device_config = device_config
|
||||||
|
self.vision_language_config = vision_language_config
|
||||||
|
|
||||||
assert self.parallel_config.worker_use_ray
|
assert self.parallel_config.worker_use_ray
|
||||||
placement_group = self.parallel_config.placement_group
|
placement_group = self.parallel_config.placement_group
|
||||||
@ -181,6 +183,7 @@ class RayGPUExecutor(ExecutorBase):
|
|||||||
driver_rank,
|
driver_rank,
|
||||||
distributed_init_method,
|
distributed_init_method,
|
||||||
lora_config=self.lora_config,
|
lora_config=self.lora_config,
|
||||||
|
vision_language_config=self.vision_language_config,
|
||||||
kv_cache_dtype=kv_cache_dtype,
|
kv_cache_dtype=kv_cache_dtype,
|
||||||
is_driver_worker=True,
|
is_driver_worker=True,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -7,9 +7,14 @@ import torch.nn as nn

from vllm.config import DeviceConfig, ModelConfig
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.llava import LlavaForConditionalGeneration
from vllm.model_executor.weight_utils import (get_quant_config,
                                              initialize_dummy_weights)

_VISION_MODEL_CLASSES = [
    LlavaForConditionalGeneration,
]


@contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype):

@ -40,6 +45,7 @@ def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]:
def get_model(model_config: ModelConfig, device_config: DeviceConfig,
              **kwargs) -> nn.Module:
    lora_config = kwargs.get("lora_config", None)
    vision_language_config = kwargs.get("vision_language_config", None)
    model_class = _get_model_architecture(model_config)

    # Get the (maybe quantized) linear method.

@ -76,7 +82,11 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig,
                "be added in the future. If this is important to you, "
                "please open an issue on github.")
        else:
            if model_class not in _VISION_MODEL_CLASSES:
                model = model_class(model_config.hf_config, linear_method)
            else:
                model = model_class(model_config.hf_config,
                                    vision_language_config, linear_method)
        if model_config.load_format == "dummy":
            # NOTE(woosuk): For accurate performance evaluation, we assign
            # random values to the weights.
@ -29,6 +29,8 @@ _MODELS = {
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "LlavaForConditionalGeneration":
    ("llava", "LlavaForConditionalGeneration"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
@ -250,14 +250,21 @@ class LlamaModel(nn.Module):
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
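The hunk above is what lets a vision front-end feed precomputed embeddings into the Llama backbone: when `inputs_embeds` is supplied, the token-embedding lookup is skipped. A minimal, self-contained sketch of that control flow in plain PyTorch (ToyTextModel and its sizes are illustrative, not vLLM code):

import torch
import torch.nn as nn


class ToyTextModel(nn.Module):
    """Toy stand-in for a text backbone with an optional embedding bypass."""

    def __init__(self, vocab_size: int = 128, hidden_size: int = 16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.layer = nn.Linear(hidden_size, hidden_size)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(self, input_ids, inputs_embeds=None):
        # Same branch structure as the LlamaModel change above.
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        return self.layer(hidden_states)


model = ToyTextModel()
ids = torch.randint(0, 128, (1, 8))
out_from_ids = model(ids)  # normal token-id path
out_from_embeds = model(None, inputs_embeds=model.get_input_embeddings(ids))
assert torch.allclose(out_from_ids, out_from_embeds)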
246
vllm/model_executor/models/llava.py
Normal file
@ -0,0 +1,246 @@
from typing import List, Optional, Tuple

import torch
from torch import nn
# TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on
# transformers' impl.
from transformers import CLIPVisionModel, LlavaConfig

from vllm.attention import AttentionMetadata
from vllm.config import VisionLanguageConfig
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]

_KEYS_TO_MODIFY_MAPPING = {
    "language_model.lm_head": "lm_head",
    "language_model.model": "language_model",
}


# TODO(xwjiang): Run benchmark and decide if TP.
class LlavaMultiModalProjector(nn.Module):

    def __init__(self, vision_hidden_size: int, text_hidden_size: int,
                 projector_hidden_act: str):
        super().__init__()

        self.linear_1 = nn.Linear(vision_hidden_size,
                                  text_hidden_size,
                                  bias=True)
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = nn.Linear(text_hidden_size,
                                  text_hidden_size,
                                  bias=True)

    def forward(self, image_features):
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


def _merge_vision_embeddings(input_ids: torch.Tensor,
                             inputs_embeds: torch.Tensor,
                             vision_embeddings: torch.Tensor,
                             image_token_id: int):
    """In place merges in vision_embeddings with inputs_embeds."""
    mask = (input_ids == image_token_id)
    inputs_embeds[mask] = vision_embeddings.view(-1,
                                                 vision_embeddings.shape[-1])


class LlavaForConditionalGeneration(nn.Module):

    def __init__(self,
                 config: "LlavaConfig",
                 vision_language_config: VisionLanguageConfig,
                 linear_method: Optional["LinearMethodBase"] = None) -> None:
        super().__init__()
        self.config = config

        self.vision_language_config = vision_language_config

        assert self.vision_language_config, (
            "Provide `image_input_type` and other vision "
            "related configurations through LLM entrypoint "
            "or engine arguments.")

        if self.vision_language_config.image_input_type == (
                VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
            self.vision_tower = CLIPVisionModel(config.vision_config)
        else:
            self.vision_tower = None

        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.linear_method = linear_method
        self.language_model = LlamaModel(config.text_config, linear_method)
        self.unpadded_vocab_size = config.text_config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.text_config.hidden_size,
            org_num_embeddings=self.language_model.org_vocab_size)
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size, logit_scale)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        attn_metadata: AttentionMetadata,
        image_input: Optional[torch.Tensor] = None
    ) -> SamplerOutput:  # noqa: E501
        """Run forward pass for Llava 1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.
        Concretely, consider a text prompt:
        "<image>\nUSER: What's the content of the image?\nASSISTANT:".
        Tokenizer outputs:
        [1, 32000, 29871, 13, 11889, 29901, 1724, 29915, 29879, 278,
        2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901].
        The to-be-inserted image has a size of 576 (24 * 24) along the context
        length dimension.
        `input_ids` is thus [1, 32000, ..., 32000, 29871, 13, 11889, 29901,
        1724, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933,
        9047, 13566, 29901].
        There will be 576 `32000` in the `input_ids`.
        (32000 is the token id for `<image>`.)

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        The model takes two types of image inputs:
        PIXEL_VALUES and IMAGE_FEATURES.
        The following shows how each maps to huggingface implementation.
        PIXEL_VALUES:
        - https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L353
        IMAGE_FEATURES:
        - https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L430
        before going through the multi modal projector.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            image_input: A batch of image inputs.
                For PIXEL_VALUES, expecting [1, 3, 336, 336].
                For IMAGE_FEATURES, expecting [1, 576, 1024].
        """
        if image_input is not None:
            if list(image_input.shape[1:]) != list(
                    self.vision_language_config.image_input_shape[1:]):
                raise ValueError(
                    f"The expected image tensor shape is batch dimension "
                    f"plus "
                    f"{self.vision_language_config.image_input_shape[1:]}."
                    f" You supplied {image_input.shape}. "
                    f"If you are using vLLM's entrypoint, make sure your "
                    f"supplied image input is consistent with "
                    f"image_input_shape in engine args.")
            if self.vision_tower is not None:
                # TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
                image_outputs = self.vision_tower(image_input,
                                                  output_hidden_states=True)
                image_features = image_outputs.hidden_states[
                    self.config.vision_feature_layer]
                # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
                if self.config.vision_feature_select_strategy == "default":
                    image_features = image_features[:, 1:]
                elif self.config.vision_feature_select_strategy == "full":
                    image_features = image_features
                else:
                    raise ValueError(
                        f"Unexpected select feature strategy: "
                        f"{self.config.vision_feature_select_strategy}")
            else:
                image_features = image_input
            vision_embeddings = self.multi_modal_projector(image_features)
            inputs_embeds = self.language_model.get_input_embeddings(input_ids)
            _merge_vision_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.vision_language_config.image_token_id)
            input_ids = None
        else:
            inputs_embeds = None
        hidden_states = self.language_model(input_ids,
                                            positions,
                                            kv_caches,
                                            attn_metadata,
                                            inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        # only doing this for language model part for now.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "rotary_emb.inv_freq" in name:
                continue
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            use_default_weight_loading = False
            if "vision" in name:
                if self.vision_tower is not None:
                    # We only do sharding for language model and
                    # not vision model for now.
                    use_default_weight_loading = True
            else:
                for (param_name, weight_name,
                     shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    param = params_dict[name.replace(weight_name, param_name)]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    use_default_weight_loading = True
            if use_default_weight_loading:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
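To make the `_merge_vision_embeddings` step above concrete, here is a small self-contained sketch of the in-place scatter it performs; the token id, sequence length, and hidden size below are illustrative only:

import torch

image_token_id = 99
input_ids = torch.tensor([1, 99, 99, 99, 99, 5, 6])  # prompt with 4 <image> slots
inputs_embeds = torch.zeros(7, 8)                    # toy text embeddings
vision_embeddings = torch.ones(1, 4, 8)              # toy projector output

# Overwrite only the positions holding the image placeholder token.
mask = (input_ids == image_token_id)
inputs_embeds[mask] = vision_embeddings.view(-1, vision_embeddings.shape[-1])

assert inputs_embeds[1:5].sum() == 4 * 8  # image slots were overwritten
assert inputs_embeds[0].sum() == 0        # text slots are untouched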
@ -303,6 +303,25 @@ class SequenceGroupState:
    generator: Optional = None


class MultiModalData:
    """Multi modal request.

    Args:
        type: The data type.
        data: The actual data.
            The required shape and semantic meaning of it depends on the vision
            language config of the hosted model.
            See `VisionLanguageConfig` in `config.py`.
    """

    class Type(enum.Enum):
        IMAGE = enum.auto()

    def __init__(self, type: Type, data: "torch.Tensor"):
        self.type = type
        self.data = data


class SequenceGroup:
    """A group of sequences that are generated from the same prompt.

@ -312,6 +331,7 @@ class SequenceGroup:
        sampling_params: The sampling parameters used to generate the outputs.
        arrival_time: The arrival time of the request.
        lora_request: LoRA request.
        multi_modal_data: Multi modal data associated with the request.
    """

    def __init__(

@ -321,6 +341,7 @@ class SequenceGroup:
        sampling_params: SamplingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> None:
        self.request_id = request_id
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}

@ -333,6 +354,7 @@ class SequenceGroup:
        self.lora_request = lora_request
        self.prompt_logprobs: Optional[PromptLogprobs] = None
        self.state = SequenceGroupState()
        self.multi_modal_data = multi_modal_data

    @property
    def prompt(self) -> str:

@ -450,6 +472,7 @@ class SequenceGroupMetadata:
            numbers)
        state: Internal state tied to this sequence group.
        lora_request: LoRA request.
        multi_modal_data: Multi modal data.
    """

    def __init__(

@ -462,6 +485,7 @@ class SequenceGroupMetadata:
        lora_request: Optional[LoRARequest] = None,
        computed_block_nums: Optional[List[int]] = None,
        state: Optional[SequenceGroupState] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> None:
        self.request_id = request_id
        self.is_prompt = is_prompt

@ -470,6 +494,7 @@ class SequenceGroupMetadata:
        self.block_tables = block_tables
        self.lora_request = lora_request
        self.computed_block_nums = computed_block_nums
        self.multi_modal_data = multi_modal_data
        self.state = SequenceGroupState() if state is None else state

    @property
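A short sketch of how a request-side caller wraps an image tensor in the new `MultiModalData` type; the zero tensor below is a stand-in for real pixel values and its shape is assumed to match the engine's `image_input_shape`:

import torch

from vllm.sequence import MultiModalData

# Dummy pixel values; a real caller would load a preprocessed [1, 3, 336, 336] tensor.
pixel_values = torch.zeros(1, 3, 336, 336)
mm_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
assert mm_data.type is MultiModalData.Type.IMAGE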
@ -40,3 +40,17 @@ def get_config(model: str,
                        revision=revision,
                        code_revision=code_revision)
    return config


def get_hf_text_config(config: PretrainedConfig):
    """Get the "sub" config relevant to llm for multi modal models.
    No op for pure text models.
    """
    if hasattr(config, "text_config"):
        # The code operates under the assumption that text_config should have
        # `num_attention_heads` (among others). Assert here to fail early
        # if transformers config doesn't align with this assumption.
        assert hasattr(config.text_config, "num_attention_heads")
        return config.text_config
    else:
        return config
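A hedged usage sketch of `get_hf_text_config`, assuming the helper is exposed from `vllm.transformers_utils.config` (the module this hunk appears to modify): a multi-modal `LlavaConfig` yields its nested `text_config`, while a plain `LlamaConfig` is returned unchanged.

from transformers import LlamaConfig, LlavaConfig

from vllm.transformers_utils.config import get_hf_text_config  # assumed module path

llava_cfg = LlavaConfig()  # multi-modal config nesting a Llama text_config
llama_cfg = LlamaConfig()  # text-only config

# The multi-modal config is unwrapped to its language-model sub-config...
assert get_hf_text_config(llava_cfg) is llava_cfg.text_config
# ...while the pure text config passes through untouched.
assert get_hf_text_config(llama_cfg) is llama_cfg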
@ -377,6 +377,16 @@ class CudaMemoryProfiler:
        gc.collect()


def str_to_int_tuple(s: str) -> Tuple[int]:
    """Convert a string to a tuple of integers."""
    try:
        return tuple(map(int, s.split(",")))
    except ValueError as e:
        raise ValueError(
            "String must be a series of integers separated by commas "
            f"(e.g., 1, 2, 3). Given input: {s}") from e


def pad_to_max_length(x: List[int], max_len: int, pad: int) -> List[int]:
    assert len(x) <= max_len
    return x + [pad] * (max_len - len(x))
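A quick usage sketch of `str_to_int_tuple` (assuming it lives in `vllm.utils` alongside `pad_to_max_length`); this is how an `image_input_shape` string such as "1,3,336,336" from the engine arguments becomes a tuple:

from vllm.utils import str_to_int_tuple  # assumed module path

assert str_to_int_tuple("1,3,336,336") == (1, 3, 336, 336)

try:
    str_to_int_tuple("1,3,not-a-number")
except ValueError as e:
    print(e)  # explains the expected comma-separated integer format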
@ -8,7 +8,7 @@ import torch.nn as nn

from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig)
                         SchedulerConfig, VisionLanguageConfig)
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest

@ -21,7 +21,8 @@ from vllm.model_executor.parallel_utils.communication_op import (
from vllm.model_executor.parallel_utils.parallel_state import (
    with_cupy_nccl_for_all_reduce)
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData,
                           SequenceGroupMetadata)
from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d,
                        is_pin_memory_available, make_tensor_with_pad,
                        maybe_expand_dim)

@ -49,6 +50,7 @@ class ModelRunner:
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        vision_language_config: Optional[VisionLanguageConfig] = None,
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config

@ -83,15 +85,18 @@ class ModelRunner:
        self.graph_block_tables = None  # Set after initial profiling.
        self.pin_memory = is_pin_memory_available()
        self.kv_cache_dtype = kv_cache_dtype
        self.vision_language_config = vision_language_config

        self.attn_backend = get_attn_backend(
            self.model_config.dtype if model_config is not None else None)

    def load_model(self) -> None:
        with CudaMemoryProfiler() as m:
            self.model = get_model(self.model_config,
            self.model = get_model(
                self.model_config,
                self.device_config,
                lora_config=self.lora_config,
                vision_language_config=self.vision_language_config,
                parallel_config=self.parallel_config,
                scheduler_config=self.scheduler_config)

@ -130,7 +135,8 @@ class ModelRunner:
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
               List[int], List[int], List[int], Set[LoRARequest]]:
               List[int], List[int], List[int], Set[LoRARequest],
               torch.Tensor]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []

@ -143,6 +149,7 @@ class ModelRunner:
        context_lens: List[int] = []
        subquery_lens: List[int] = []
        prefix_block_tables: List[List[int]] = []
        multi_modal_input_list: List[torch.Tensor] = []
        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())

@ -188,6 +195,10 @@ class ModelRunner:
                (prompt_len - computed_len
                 if seq_group_metadata.sampling_params.prompt_logprobs else 1))

            if seq_group_metadata.multi_modal_data:
                multi_modal_input_list.append(
                    seq_group_metadata.multi_modal_data.data)

            if seq_group_metadata.block_tables is None:
                # During memory profiling, the block tables are not initialized
                # yet. In this case, we just use a dummy slot mapping.

@ -236,6 +247,16 @@ class ModelRunner:
        context_lens_tensor = torch.tensor(context_lens,
                                           dtype=torch.int,
                                           device=self.device)

        if multi_modal_input_list:
            assert self.vision_language_config, (
                "Multi-modal inputs are only supported by "
                "vision language models.")
            multi_modal_input = torch.cat(multi_modal_input_list,
                                          dim=0).to(self.device)
        else:
            multi_modal_input = None

        # Prepare prefix block tables
        max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
        block_tables = make_tensor_with_pad(

@ -291,7 +312,7 @@ class ModelRunner:
        )
        return (input_tokens, input_positions, attn_metadata, prompt_lens,
                subquery_lens, lora_index_mapping, lora_prompt_mapping,
                lora_requests)
                lora_requests, multi_modal_input)

    def _prepare_decode(
        self,

@ -525,7 +546,7 @@ class ModelRunner:
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
               Set[int], LoRAMapping]:
               Set[int], LoRAMapping, torch.Tensor]:
        if self.is_driver_worker:
            # NOTE: We assume that all sequences in the group are all prompts or
            # all decodes.

@ -534,13 +555,15 @@ class ModelRunner:
            if is_prompt:
                (input_tokens, input_positions, attn_metadata, prompt_lens,
                 subquery_lens, lora_index_mapping, lora_prompt_mapping,
                 lora_requests) = self._prepare_prompt(seq_group_metadata_list)
                 lora_requests, multi_modal_input
                 ) = self._prepare_prompt(seq_group_metadata_list)
            else:
                (input_tokens, input_positions, attn_metadata,
                 lora_index_mapping, lora_prompt_mapping,
                 lora_requests) = self._prepare_decode(seq_group_metadata_list)
                prompt_lens = []
                subquery_lens = None
                multi_modal_input = None
            sampling_metadata = self._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens,
                                                     subquery_lens)

@ -561,6 +584,7 @@ class ModelRunner:
                sampling_metadata.selected_token_indices,
                "lora_requests": lora_requests,
                "lora_mapping": lora_mapping,
                "multi_modal_input": multi_modal_input,
            }
            metadata_dict.update(attn_metadata.asdict_zerocopy())
            broadcast_tensor_dict(metadata_dict, src=0)

@ -572,6 +596,7 @@ class ModelRunner:
                "selected_token_indices")
            lora_mapping = metadata_dict.pop("lora_mapping")
            lora_requests = metadata_dict.pop("lora_requests")
            multi_modal_input = metadata_dict.pop("multi_modal_input")
            attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
            sampling_metadata = SamplingMetadata(
                seq_groups=None,

@ -584,7 +609,8 @@ class ModelRunner:
            )

        return (input_tokens, input_positions, attn_metadata,
                sampling_metadata, lora_requests, lora_mapping)
                sampling_metadata, lora_requests, lora_mapping,
                multi_modal_input)

    @torch.inference_mode()
    def execute_model(

@ -593,8 +619,8 @@ class ModelRunner:
        kv_caches: List[torch.Tensor],
    ) -> Optional[SamplerOutput]:
        (input_tokens, input_positions, attn_metadata, sampling_metadata,
         lora_requests,
         lora_mapping) = self.prepare_input_tensors(seq_group_metadata_list)
         lora_requests, lora_mapping, multi_modal_input
         ) = self.prepare_input_tensors(seq_group_metadata_list)

        if self.lora_config:
            self.set_active_loras(lora_requests, lora_mapping)

@ -605,12 +631,15 @@ class ModelRunner:
            model_executable = self.graph_runners[graph_batch_size]
        else:
            model_executable = self.model
        hidden_states = model_executable(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
        )
        execute_model_kwargs = {
            "input_ids": input_tokens,
            "positions": input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": attn_metadata,
        }
        if self.vision_language_config:
            execute_model_kwargs.update({"image_input": multi_modal_input})
        hidden_states = model_executable(**execute_model_kwargs)

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

@ -658,10 +687,22 @@ class ModelRunner:
        # Profile memory usage with max_num_sequences sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        # Additional GPU memory may be needed for vision encoding, which needs
        # to be accounted for when calculating the GPU blocks for
        # vLLM blocker manager.
        # To exercise the worst scenario for GPU memory consumption,
        # the number of seqs (batch_size) is chosen to maximize the number
        # of images processed.
        if self.vision_language_config:
            max_num_seqs = min(
                max_num_seqs,
                int(max_num_batched_tokens /
                    self.vision_language_config.image_feature_size))
        for group_id in range(max_num_seqs):
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            seq_data = SequenceData([0] * seq_len)
            seq_data, fake_multi_modal_input = _prepare_fake_inputs(
                seq_len, self.vision_language_config)
            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,

@ -670,6 +711,7 @@ class ModelRunner:
                block_tables=None,
                lora_request=dummy_lora_requests_per_seq[group_id]
                if dummy_lora_requests_per_seq else None,
                multi_modal_data=fake_multi_modal_input,
            )
            seqs.append(seq)

@ -831,6 +873,7 @@ class CUDAGraphRunner:
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        memory_pool,
        **kwargs,
    ) -> None:
        assert self.graph is None
        # Run the model once without capturing the graph.

@ -842,6 +885,7 @@ class CUDAGraphRunner:
            positions,
            kv_caches,
            attn_metadata,
            **kwargs,
        )
        torch.cuda.synchronize()

@ -856,6 +900,7 @@ class CUDAGraphRunner:
                positions,
                kv_caches,
                attn_metadata,
                **kwargs,
            )
        torch.cuda.synchronize()

@ -877,6 +922,7 @@ class CUDAGraphRunner:
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> torch.Tensor:
        # KV caches are fixed tensors, so we don't need to copy them.
        del kv_caches

@ -922,3 +968,21 @@ def _get_graph_batch_size(batch_size: int) -> int:
    else:
        return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
                _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)


def _prepare_fake_inputs(
        seq_len: int, vision_language_config: Optional[VisionLanguageConfig]):
    """Prepare fake inputs for profile run."""
    if vision_language_config:
        prompt_tokens = [
            vision_language_config.image_token_id
        ] * vision_language_config.image_feature_size + [0] * (
            seq_len - vision_language_config.image_feature_size)
        fake_image_input = MultiModalData(
            type=MultiModalData.Type.IMAGE,
            data=torch.zeros(vision_language_config.image_input_shape,
                             dtype=torch.float16))
    else:
        prompt_tokens = [0] * seq_len
        fake_image_input = None
    return SequenceData(prompt_tokens), fake_image_input
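The profiling change above caps the number of fake sequences so that every sequence can carry one full image's worth of image tokens during the memory profile run. A worked example with illustrative numbers (the engine limits below are assumptions, not values from this commit):

# Illustrative arithmetic for the profiling cap applied in the hunk above.
max_num_batched_tokens = 2048  # assumed engine limit for this example
max_num_seqs = 256             # assumed engine limit for this example
image_feature_size = 576       # 24 x 24 patches for LLaVA-1.5 at 336 px

# Batch size is bounded by tokens-per-batch divided by tokens-per-image.
max_num_seqs = min(max_num_seqs, int(max_num_batched_tokens / image_feature_size))
print(max_num_seqs)            # -> 3; each fake sequence then gets ~682 tokens,
                               #    enough to hold the 576 image placeholder tokens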
@ -7,7 +7,7 @@ import torch
import torch.distributed

from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig)
                         ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.model_executor.parallel_utils import cupy_utils

@ -39,6 +39,7 @@ class Worker:
        rank: int,
        distributed_init_method: str,
        lora_config: Optional[LoRAConfig] = None,
        vision_language_config: Optional[VisionLanguageConfig] = None,
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
    ) -> None:

@ -54,13 +55,20 @@ class Worker:
        if self.is_driver_worker:
            assert self.rank == 0, "The driver worker must have rank 0."

        self.model_runner = ModelRunner(model_config,
        self.vision_language_config = vision_language_config
        if self.vision_language_config:
            assert not self.lora_config, (
                "To be tested: vision language model with LoRA settings.")

        self.model_runner = ModelRunner(
            model_config,
            parallel_config,
            scheduler_config,
            device_config,
            lora_config=self.lora_config,
            kv_cache_dtype=kv_cache_dtype,
            is_driver_worker=is_driver_worker)
            is_driver_worker=is_driver_worker,
            vision_language_config=vision_language_config)
        # Uninitialized cache engine. Will be initialized by
        # self.init_cache_engine().
        self.cache_config = None