[VLM] Separate text-only and vision variants of the same model architecture (#13157)
This commit is contained in: parent 02ed8a1fbe · commit 1bc3b5e71b
@@ -699,10 +699,10 @@ See [this page](#generative-models) for more information on how to use generativ
*
* ✅︎
* ✅︎
- * `DeepseekVLV2ForCausalLM`
- * `DeepseekVLV2ForCausalLM`<sup>^</sup>
* DeepSeek-VL2
* T + I<sup>+</sup>
* `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2` etc. (see note)
* `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2` etc.
*
* ✅︎
* ✅︎
@@ -713,10 +713,10 @@ See [this page](#generative-models) for more information on how to use generativ
*
* ✅︎
* ✅︎
- * `ChatGLMModel`
- * `GLM4VForCausalLM`<sup>^</sup>
* GLM-4V
* T + I
* `THUDM/glm-4v-9b` etc.
* `THUDM/glm-4v-9b`, `THUDM/cogagent-9b-20241220` etc.
* ✅︎
* ✅︎
* ✅︎
@@ -825,7 +825,7 @@ See [this page](#generative-models) for more information on how to use generativ
*
* ✅︎
* ✅︎
- * `QWenLMHeadModel`
- * `QwenVLForConditionalGeneration`<sup>^</sup>
* Qwen-VL
* T + I<sup>E+</sup>
* `Qwen/Qwen-VL`, `Qwen/Qwen-VL-Chat`, etc.
@@ -862,13 +862,12 @@ See [this page](#generative-models) for more information on how to use generativ
* ✅︎
:::

<sup>^</sup> You need to set the architecture name via `--hf-overrides` to match the one in vLLM.
• For example, to use DeepSeek-VL2 series models:
`--hf-overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'`
<sup>E</sup> Pre-computed embeddings can be inputted for this modality.
<sup>+</sup> Multiple items can be inputted per text prompt for this modality.

:::{note}
To use DeepSeek-VL2 series models, you have to pass `--hf_overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` when running vLLM.
:::
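For offline inference, the same override can be passed programmatically rather than on the command line. A minimal sketch (model name taken from the table above; it mirrors the `hf_overrides` usage in the example scripts changed later in this diff):

from vllm import LLM

llm = LLM(
    model="deepseek-ai/deepseek-vl2-tiny",
    hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]},
)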

:::{note}
H2O-VL series models will be available in V1 once we support backends other than FlashAttention.
:::

@@ -105,7 +105,9 @@ def run_glm4v(question: str, modality: str):
max_num_seqs=2,
trust_remote_code=True,
enforce_eager=True,
hf_overrides={"architectures": ["GLM4VForCausalLM"]},
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)

prompt = f"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>\
{question}<|assistant|>"

@@ -495,6 +497,7 @@ def run_qwen_vl(question: str, modality: str):
trust_remote_code=True,
max_model_len=1024,
max_num_seqs=2,
hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]},
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)


@@ -77,7 +77,7 @@ def load_deepseek_vl2(question: str, image_urls: List[str]):
)


def load_h2onvl(question: str, image_urls: List[str]) -> ModelRequestData:
def load_h2ovl(question: str, image_urls: List[str]) -> ModelRequestData:
model_name = "h2oai/h2ovl-mississippi-2b"

llm = LLM(
@@ -302,6 +302,7 @@ def load_qwen_vl_chat(question: str,
trust_remote_code=True,
max_model_len=1024,
max_num_seqs=2,
hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]},
limit_mm_per_prompt={"image": len(image_urls)},
)
placeholders = "".join(f"Picture {i}: <img></img>\n"
@@ -452,7 +453,7 @@ def load_qwen2_5_vl(question, image_urls: List[str]) -> ModelRequestData:
model_example_map = {
"aria": load_aria,
"deepseek_vl_v2": load_deepseek_vl2,
"h2ovl_chat": load_h2onvl,
"h2ovl_chat": load_h2ovl,
"idefics3": load_idefics3,
"internvl_chat": load_internvl,
"mllama": load_mllama,

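The example updates above all follow the same pattern: when a checkpoint's config declares an architecture that vLLM now splits into text-only and vision variants, the vision variant is selected explicitly via `hf_overrides`. A minimal sketch (checkpoint names and overrides taken from the examples above; the `trust_remote_code` setting follows those examples and may not be needed for every model):

from vllm import LLM

# GLM-4V checkpoint: pick the vision variant registered in vLLM
llm = LLM(
    model="THUDM/glm-4v-9b",
    trust_remote_code=True,
    hf_overrides={"architectures": ["GLM4VForCausalLM"]},
)

# Qwen-VL checkpoint: same idea with a different architecture name
llm = LLM(
    model="Qwen/Qwen-VL",
    trust_remote_code=True,
    hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]},
)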
@@ -6,6 +6,7 @@ WARNING: This test runs in both single-node (4 GPUs) and multi-node
all workers in a node other than the head node, which can cause the test
to fail.
"""
import json
import os
from dataclasses import dataclass
from typing import List, Literal, NamedTuple, Optional
@@ -15,6 +16,7 @@ import pytest
from vllm.config import TaskOption
from vllm.logger import init_logger

from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import compare_two_settings, fork_new_process_for_each_test

logger = init_logger("test_pipeline_parallel")
@@ -31,10 +33,7 @@ class ParallelSetup(NamedTuple):

class PPTestOptions(NamedTuple):
multi_node_only: bool
trust_remote_code: bool
tokenizer_mode: Optional[str]
load_format: Optional[str] = None
hf_overrides: Optional[str] = None


@dataclass
@@ -64,10 +63,7 @@ class PPTestSettings:
pp_base: int = 2,
multi_node_only: bool = False,
task: TaskOption = "auto",
trust_remote_code: bool = False,
tokenizer_mode: Optional[str] = None,
load_format: Optional[str] = None,
hf_overrides: Optional[str] = None,
):
return PPTestSettings(
parallel_setups=[
@@ -97,10 +93,7 @@ class PPTestSettings:
vllm_major_versions=["0", "0", "1"],
task=task,
test_options=PPTestOptions(multi_node_only=multi_node_only,
trust_remote_code=trust_remote_code,
tokenizer_mode=tokenizer_mode,
load_format=load_format,
hf_overrides=hf_overrides),
load_format=load_format),
)

@staticmethod
@@ -110,10 +103,7 @@ class PPTestSettings:
pp_base: int = 2,
task: TaskOption = "auto",
multi_node_only: bool = False,
trust_remote_code: bool = False,
tokenizer_mode: Optional[str] = None,
load_format: Optional[str] = None,
hf_overrides: Optional[str] = None,
):
return PPTestSettings(
parallel_setups=[
@@ -126,19 +116,16 @@ class PPTestSettings:
vllm_major_versions=["0"],
task=task,
test_options=PPTestOptions(multi_node_only=multi_node_only,
trust_remote_code=trust_remote_code,
tokenizer_mode=tokenizer_mode,
load_format=load_format,
hf_overrides=hf_overrides),
load_format=load_format),
)

def iter_params(self, model_name: str):
def iter_params(self, model_id: str):
opts = self.test_options

for parallel_setup in self.parallel_setups:
for backend, vllm_major_version in zip(self.distributed_backends,
self.vllm_major_versions):
yield (model_name, parallel_setup, backend, vllm_major_version,
yield (model_id, parallel_setup, backend, vllm_major_version,
self.task, opts)


@@ -150,16 +137,16 @@ TEXT_GENERATION_MODELS = {
# [Decoder-only]
# Uses Llama
# "BAAI/AquilaChat-7B": PPTestSettings.fast(),
"Snowflake/snowflake-arctic-instruct": PPTestSettings.fast(tp_base=8, trust_remote_code=True), # noqa: E501
"baichuan-inc/Baichuan-7B": PPTestSettings.fast(trust_remote_code=True),
"baichuan-inc/Baichuan2-13B-Chat": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
"Snowflake/snowflake-arctic-instruct": PPTestSettings.fast(load_format="dummy"), # noqa: E501
"baichuan-inc/Baichuan-7B": PPTestSettings.fast(),
"baichuan-inc/Baichuan2-13B-Chat": PPTestSettings.fast(),
"bigscience/bloomz-1b1": PPTestSettings.fast(),
"THUDM/chatglm3-6b": PPTestSettings.fast(trust_remote_code=True),
"CohereForAI/c4ai-command-r-v01": PPTestSettings.fast(tp_base=2, trust_remote_code=True), # noqa: E501
"databricks/dbrx-instruct": PPTestSettings.fast(tp_base=8),
"Deci/DeciLM-7B-instruct": PPTestSettings.fast(trust_remote_code=True),
"THUDM/chatglm3-6b": PPTestSettings.fast(),
"CohereForAI/c4ai-command-r-v01": PPTestSettings.fast(load_format="dummy"),
"databricks/dbrx-instruct": PPTestSettings.fast(load_format="dummy"),
"Deci/DeciLM-7B-instruct": PPTestSettings.fast(),
"deepseek-ai/deepseek-llm-7b-chat": PPTestSettings.fast(),
"deepseek-ai/DeepSeek-V2-Lite-Chat": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
"deepseek-ai/DeepSeek-V2-Lite-Chat": PPTestSettings.fast(),
"LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct": PPTestSettings.fast(),
"tiiuae/falcon-7b": PPTestSettings.fast(),
"google/gemma-2b": PPTestSettings.fast(),
@@ -172,36 +159,36 @@ TEXT_GENERATION_MODELS = {
"ibm/PowerMoE-3b": PPTestSettings.fast(),
# Uses Llama
# "internlm/internlm-chat-7b": PPTestSettings.fast(),
"internlm/internlm2-chat-7b": PPTestSettings.fast(trust_remote_code=True),
"internlm/internlm2-chat-7b": PPTestSettings.fast(),
"inceptionai/jais-13b-chat": PPTestSettings.fast(),
"ai21labs/Jamba-tiny-dev": PPTestSettings.fast(),
"meta-llama/Meta-Llama-3-8B": PPTestSettings.detailed(),
"openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(trust_remote_code=True),
"openbmb/MiniCPM3-4B": PPTestSettings.fast(trust_remote_code=True),
"openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(),
"openbmb/MiniCPM3-4B": PPTestSettings.fast(),
# Uses Llama
# "mistralai/Mistral-7B-Instruct-v0.1": PPTestSettings.fast(),
"state-spaces/mamba-130m-hf": PPTestSettings.fast(),
"mistralai/Mixtral-8x7B-Instruct-v0.1": PPTestSettings.fast(tp_base=4),
"mistralai/Mixtral-8x7B-Instruct-v0.1": PPTestSettings.fast(load_format="dummy"), # noqa: E501
"mosaicml/mpt-7b": PPTestSettings.fast(),
"nvidia/Minitron-8B-Base": PPTestSettings.fast(),
"allenai/OLMo-1B-hf": PPTestSettings.fast(),
"shanearora/OLMo-7B-1124-hf": PPTestSettings.fast(),
"allenai/OLMoE-1B-7B-0924-Instruct": PPTestSettings.fast(),
"facebook/opt-iml-max-1.3b": PPTestSettings.fast(),
"OrionStarAI/Orion-14B-Chat": PPTestSettings.fast(trust_remote_code=True),
"OrionStarAI/Orion-14B-Chat": PPTestSettings.fast(),
"adept/persimmon-8b-chat": PPTestSettings.fast(),
"microsoft/phi-2": PPTestSettings.fast(),
"microsoft/Phi-3-small-8k-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
"microsoft/Phi-3.5-MoE-instruct": PPTestSettings.detailed(trust_remote_code=True, multi_node_only=True, load_format="dummy", hf_overrides='{"num_hidden_layers": 4, "hidden_size": 512, "intermediate_size": 800, "num_attention_heads": 4, "num_key_value_heads": 1}'), # noqa: E501
"Qwen/Qwen-7B-Chat": PPTestSettings.fast(trust_remote_code=True),
"microsoft/Phi-3-small-8k-instruct": PPTestSettings.fast(),
"microsoft/Phi-3.5-MoE-instruct": PPTestSettings.detailed(multi_node_only=True, load_format="dummy"), # noqa: E501
"Qwen/Qwen-7B-Chat": PPTestSettings.fast(),
"Qwen/Qwen2-7B-Instruct": PPTestSettings.fast(),
"Qwen/Qwen1.5-MoE-A2.7B-Chat": PPTestSettings.fast(),
"stabilityai/stablelm-3b-4e1t": PPTestSettings.fast(),
"bigcode/starcoder2-3b": PPTestSettings.fast(),
"upstage/solar-pro-preview-instruct": PPTestSettings.fast(tp_base=2),
"upstage/solar-pro-preview-instruct": PPTestSettings.fast(load_format="dummy"), # noqa: E501
# FIXME: Cannot load tokenizer in latest transformers version.
# Need to use tokenizer from `meta-llama/Llama-2-7b-chat-hf`
# "xverse/XVERSE-7B-Chat": PPTestSettings.fast(trust_remote_code=True),
# "xverse/XVERSE-7B-Chat": PPTestSettings.fast(),
# [Encoder-only]
# TODO: Implement PP
# "facebook/bart-base": PPTestSettings.fast(),
@@ -211,7 +198,7 @@ EMBEDDING_MODELS = { # type: ignore[var-annotated]
# [Text-only]
"intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(),
"BAAI/bge-multilingual-gemma2": PPTestSettings.fast(),
"Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast(tp_base=4, trust_remote_code=True), # noqa: E501
"Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast(load_format="dummy"),
}

MULTIMODAL_MODELS = {
@@ -219,20 +206,20 @@ MULTIMODAL_MODELS = {
"Salesforce/blip2-opt-2.7b": PPTestSettings.fast(),
"facebook/chameleon-7b": PPTestSettings.fast(),
"adept/fuyu-8b": PPTestSettings.fast(),
"THUDM/glm-4v-9b": PPTestSettings.fast(trust_remote_code=True),
"OpenGVLab/InternVL2-1B": PPTestSettings.fast(trust_remote_code=True),
"THUDM/glm-4v-9b": PPTestSettings.fast(),
"OpenGVLab/InternVL2-1B": PPTestSettings.fast(),
"llava-hf/llava-1.5-7b-hf": PPTestSettings.fast(),
"llava-hf/llava-v1.6-mistral-7b-hf": PPTestSettings.fast(),
"llava-hf/LLaVA-NeXT-Video-7B-hf": PPTestSettings.fast(),
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf": PPTestSettings.fast(),
"openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(trust_remote_code=True),
"allenai/Molmo-7B-D-0924": PPTestSettings.fast(trust_remote_code=True),
"microsoft/Phi-3-vision-128k-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
"mistralai/Pixtral-12B-2409": PPTestSettings.fast(tp_base=2, tokenizer_mode="mistral"), # noqa: E501
"Qwen/Qwen-VL-Chat": PPTestSettings.fast(trust_remote_code=True),
"openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(),
"allenai/Molmo-7B-D-0924": PPTestSettings.fast(),
"microsoft/Phi-3-vision-128k-instruct": PPTestSettings.fast(),
"mistralai/Pixtral-12B-2409": PPTestSettings.fast(load_format="dummy"),
"Qwen/Qwen-VL-Chat": PPTestSettings.fast(),
"Qwen/Qwen2-Audio-7B-Instruct": PPTestSettings.fast(),
"Qwen/Qwen2-VL-2B-Instruct": PPTestSettings.fast(),
"fixie-ai/ultravox-v0_5-llama-3_2-1b": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
"fixie-ai/ultravox-v0_5-llama-3_2-1b": PPTestSettings.fast(),
# [Encoder-decoder]
# TODO: Implement PP
# "meta-llama/Llama-3.2-11B-Vision-Instruct": PPTestSettings.fast(),
@@ -258,7 +245,7 @@ TEST_MODELS = [


def _compare_tp(
model_name: str,
model_id: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
vllm_major_version: str,
@@ -267,6 +254,7 @@ def _compare_tp(
num_gpus_available: int,
*,
method: Literal["generate", "encode"],
is_multimodal: bool,
):
(
tp_size,
@@ -274,13 +262,32 @@ def _compare_tp(
eager_mode,
chunked_prefill,
) = parallel_setup
(
multi_node_only,
trust_remote_code,
tokenizer_mode,
load_format,
hf_overrides,
) = test_options

multi_node_only, load_format = test_options

model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
model_info.check_transformers_version(on_fail="skip")

trust_remote_code = model_info.trust_remote_code
tokenizer_mode = model_info.tokenizer_mode
hf_overrides = model_info.hf_overrides

if load_format == "dummy":
# Avoid OOM
text_overrides = {
"num_layers": 1,
"num_hidden_layers": 1,
"num_experts": 2,
"num_experts_per_tok": 2,
"num_local_experts": 2,
}

if is_multimodal:
hf_overrides.update({"text_config": text_overrides})
else:
hf_overrides.update(text_overrides)
else:
model_info.check_available_online(on_fail="skip")

if num_gpus_available < tp_size * pp_size:
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
@@ -312,7 +319,7 @@ def _compare_tp(
if load_format:
common_args.extend(["--load-format", load_format])
if hf_overrides:
common_args.extend(["--hf-overrides", hf_overrides])
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])

specific_case = tp_size == 2 and pp_size == 2 and chunked_prefill
if distributed_backend == "ray" and (vllm_major_version == "1"
@@ -355,11 +362,7 @@ def _compare_tp(
]

try:
compare_two_settings(model_name,
pp_args,
tp_args,
pp_env,
method=method)
compare_two_settings(model_id, pp_args, tp_args, pp_env, method=method)
except Exception:
if pp_env is None:
raise
@@ -369,17 +372,16 @@ def _compare_tp(


@pytest.mark.parametrize(
("model_name", "parallel_setup", "distributed_backend",
"vllm_major_version", "task", "test_options"),
("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
"task", "test_options"),
[
params for model_name, settings in TEXT_GENERATION_MODELS.items()
for params in settings.iter_params(model_name)
if model_name in TEST_MODELS
params for model_id, settings in TEXT_GENERATION_MODELS.items()
for params in settings.iter_params(model_id) if model_id in TEST_MODELS
],
)
@fork_new_process_for_each_test
def test_tp_language_generation(
model_name: str,
model_id: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
vllm_major_version: str,
@@ -387,28 +389,28 @@ def test_tp_language_generation(
test_options: PPTestOptions,
num_gpus_available,
):
_compare_tp(model_name,
_compare_tp(model_id,
parallel_setup,
distributed_backend,
vllm_major_version,
task,
test_options,
num_gpus_available,
method="generate")
method="generate",
is_multimodal=False)


@pytest.mark.parametrize(
("model_name", "parallel_setup", "distributed_backend",
"vllm_major_version", "task", "test_options"),
("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
"task", "test_options"),
[
params for model_name, settings in EMBEDDING_MODELS.items()
for params in settings.iter_params(model_name)
if model_name in TEST_MODELS
params for model_id, settings in EMBEDDING_MODELS.items()
for params in settings.iter_params(model_id) if model_id in TEST_MODELS
],
)
@fork_new_process_for_each_test
def test_tp_language_embedding(
model_name: str,
model_id: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
vllm_major_version: str,
@@ -416,28 +418,28 @@ def test_tp_language_embedding(
test_options: PPTestOptions,
num_gpus_available,
):
_compare_tp(model_name,
_compare_tp(model_id,
parallel_setup,
distributed_backend,
vllm_major_version,
task,
test_options,
num_gpus_available,
method="encode")
method="encode",
is_multimodal=False)


@pytest.mark.parametrize(
("model_name", "parallel_setup", "distributed_backend",
"vllm_major_version", "task", "test_options"),
("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
"task", "test_options"),
[
params for model_name, settings in MULTIMODAL_MODELS.items()
for params in settings.iter_params(model_name)
if model_name in TEST_MODELS
params for model_id, settings in MULTIMODAL_MODELS.items()
for params in settings.iter_params(model_id) if model_id in TEST_MODELS
],
)
@fork_new_process_for_each_test
def test_tp_multimodal_generation(
model_name: str,
model_id: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
vllm_major_version: str,
@@ -445,11 +447,12 @@ def test_tp_multimodal_generation(
test_options: PPTestOptions,
num_gpus_available,
):
_compare_tp(model_name,
_compare_tp(model_id,
parallel_setup,
distributed_backend,
vllm_major_version,
task,
test_options,
num_gpus_available,
method="generate")
method="generate",
is_multimodal=True)

@@ -155,10 +155,7 @@ VLM_TEST_SETTINGS = {
auto_cls=AutoModelForVision2Seq,
vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
marks=[pytest.mark.skipif(
TRANSFORMERS_VERSION < "4.49.0",
reason="HF model requires transformers>=4.49.0",
), pytest.mark.core_model, pytest.mark.cpu_model],
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
),
#### Extended model tests
"aria": VLMTestInfo(
@@ -215,7 +212,6 @@ VLM_TEST_SETTINGS = {
"cherry_blossom": "<image>\nPlease infer the season with reason in details.", # noqa: E501
}),
multi_image_prompt="image_1:<image>\nimage_2:<image>\nWhich image can we see the car and the tower?", # noqa: E501
vllm_runner_kwargs={"hf_overrides": {"architectures": ["DeepseekVLV2ForCausalLM"]}}, # noqa: E501
patch_hf_runner=model_utils.deepseekvl2_patch_hf_runner,
postprocess_inputs=model_utils.cast_dtype_post_processor("images"),
hf_output_post_proc=model_utils.deepseekvl2_trunc_hf_output,
@@ -240,7 +236,7 @@ VLM_TEST_SETTINGS = {
num_logprobs=10,
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
),
"glm4": VLMTestInfo(
"glm4v": VLMTestInfo(
models=["THUDM/glm-4v-9b"],
test_type=VLMTestType.IMAGE,
prompt_formatter=identity,
@@ -351,7 +347,6 @@ VLM_TEST_SETTINGS = {
postprocess_inputs=model_utils.cast_dtype_post_processor(
"pixel_values"
),
vllm_runner_kwargs={"hf_overrides": {"architectures": ["MantisForConditionalGeneration"]}}, # noqa: E501
get_stop_token_ids=lambda tok: [128009],
auto_cls=AutoModelForVision2Seq,
vllm_output_post_proc=model_utils.mantis_vllm_to_hf_output,
@@ -437,7 +432,7 @@ VLM_TEST_SETTINGS = {
auto_cls=AutoModelForVision2Seq,
marks=[large_gpu_mark(min_gb=48)],
),
"qwen": VLMTestInfo(
"qwen_vl": VLMTestInfo(
models=["Qwen/Qwen-VL"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=identity,

@@ -4,12 +4,14 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import torch
from PIL.Image import Image
from transformers import AutoTokenizer, BatchEncoding, PreTrainedTokenizerBase
from transformers import BatchEncoding
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from vllm.config import TaskOption
from vllm.transformers_utils.tokenizer import AnyTokenizer

from .....conftest import HfRunner, VllmRunner
from ....registry import HF_EXAMPLE_MODELS
from .types import RunnerOutput


@@ -31,10 +33,8 @@ def run_test(
use_tokenizer_eos: bool,
postprocess_inputs: Callable[[BatchEncoding], BatchEncoding],
comparator: Callable[..., None],
get_stop_token_ids: Optional[Callable[[PreTrainedTokenizerBase],
List[int]]],
get_stop_token_ids: Optional[Callable[[AnyTokenizer], list[int]]],
stop_str: Optional[List[str]],
tokenizer_mode: str,
limit_mm_per_prompt: Dict[str, int],
vllm_runner_kwargs: Optional[Dict[str, Any]],
hf_model_kwargs: Optional[Dict[str, Any]],
@@ -48,7 +48,10 @@ def run_test(
"""Modality agnostic test test executor for comparing HF/vLLM outputs."""
# In the case of embeddings, vLLM takes separate input tensors
vllm_inputs = vllm_embeddings if vllm_embeddings is not None else inputs
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)

model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")

vllm_outputs_per_mm = []
hf_outputs_per_mm = []
@@ -57,17 +60,19 @@ def run_test(
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).
vllm_kwargs: Dict[str, Any] = {}
if get_stop_token_ids is not None:
vllm_kwargs["stop_token_ids"] = get_stop_token_ids(tokenizer)
if stop_str:
vllm_kwargs["stop"] = stop_str

if vllm_runner_kwargs is None:
vllm_runner_kwargs = {}
vllm_runner_kwargs_: Dict[str, Any] = {}
if model_info.tokenizer:
vllm_runner_kwargs_["tokenizer"] = model_info.tokenizer
if model_info.tokenizer_mode:
vllm_runner_kwargs_["tokenizer_mode"] = model_info.tokenizer_mode
if model_info.hf_overrides:
vllm_runner_kwargs_["hf_overrides"] = model_info.hf_overrides

if vllm_runner_kwargs:
vllm_runner_kwargs_.update(vllm_runner_kwargs)

with vllm_runner(model,
tokenizer_mode=tokenizer_mode,
max_model_len=max_model_len,
max_num_seqs=max_num_seqs,
dtype=dtype,
@@ -76,7 +81,15 @@ def run_test(
distributed_executor_backend=distributed_executor_backend,
enforce_eager=enforce_eager,
task=task,
**vllm_runner_kwargs) as vllm_model:
**vllm_runner_kwargs_) as vllm_model:
tokenizer = vllm_model.model.get_tokenizer()

vllm_kwargs: Dict[str, Any] = {}
if get_stop_token_ids is not None:
vllm_kwargs["stop_token_ids"] = get_stop_token_ids(tokenizer)
if stop_str:
vllm_kwargs["stop"] = stop_str

for prompts, media in vllm_inputs:
vllm_kwargs[runner_mm_key] = media
vllm_output = vllm_model.generate_greedy_logprobs(
@@ -93,16 +106,19 @@ def run_test(
if patch_hf_runner is not None:
hf_model = patch_hf_runner(hf_model)

# Some models need to explicitly pass the eos_token_id off the tokenizer or
# processor for a good comparison; currently assume processor/tokenizer
# agree on the EOS, and pull it off the tokenizer if requested.
hf_kwargs = {}
if use_tokenizer_eos:
hf_kwargs["eos_token_id"] = tokenizer.eos_token_id
if stop_str:
hf_kwargs["stop_strings"] = stop_str

with hf_model, torch.no_grad():
tokenizer = hf_model.tokenizer

# Some models need to explicitly pass the eos_token_id off the tokenizer
# or processor for a good comparison;
# currently assume processor/tokenizer agree on the EOS, and pull it off
# the tokenizer if requested.
hf_kwargs = {}
if use_tokenizer_eos:
hf_kwargs["eos_token_id"] = tokenizer.eos_token_id
if stop_str:
hf_kwargs["stop_strings"] = stop_str

for prompts, media in inputs:
hf_kwargs[runner_mm_key] = media
hf_output = hf_model.generate_greedy_logprobs_limit(

@@ -8,12 +8,12 @@ from typing import (Any, Callable, Dict, Iterable, List, NamedTuple, Optional,
import torch
from PIL.Image import Image
from pytest import MarkDecorator
from transformers import (AutoModelForCausalLM, BatchEncoding,
PreTrainedTokenizerBase)
from transformers import AutoModelForCausalLM, BatchEncoding
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from vllm.config import TaskOption
from vllm.sequence import SampleLogprobs
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import identity

from .....conftest import IMAGE_ASSETS, HfRunner, ImageAsset, _ImageAssets
@@ -100,8 +100,7 @@ class VLMTestInfo(NamedTuple):
vllm_runner_kwargs: Optional[Dict[str, Any]] = None

# Optional callable which gets a list of token IDs from the model tokenizer
get_stop_token_ids: Optional[Callable[[PreTrainedTokenizerBase],
List[int]]] = None
get_stop_token_ids: Optional[Callable[[AnyTokenizer], list[int]]] = None
# Optional list of strings to stop generation, useful when stop tokens are
# not special tokens in the tokenizer
stop_str: Optional[List[str]] = None
@@ -156,8 +155,6 @@ class VLMTestInfo(NamedTuple):

marks: Optional[List[MarkDecorator]] = None

tokenizer_mode: str = "auto"

def get_non_parametrized_runner_kwargs(self):
"""Returns a dictionary of expandable kwargs for items that are used
in all test types, which are NOT used when creating the parametrized
@@ -180,7 +177,6 @@ class VLMTestInfo(NamedTuple):
"hf_model_kwargs": self.hf_model_kwargs,
"stop_str": self.stop_str,
"patch_hf_runner": self.patch_hf_runner,
"tokenizer_mode": self.tokenizer_mode
}


@@ -104,7 +104,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True),
"BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B"),
"BloomForCausalLM": _HfExamplesInfo("bigscience/bloomz-1b1"),
# ChatGLMModel supports multimodal
"ChatGLMModel": _HfExamplesInfo("THUDM/chatglm3-6b",
trust_remote_code=True),
"CohereForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r-v01",
trust_remote_code=True),
"Cohere2ForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r7b-12-2024", # noqa: E501
@@ -138,7 +139,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"InternLM3ForCausalLM": _HfExamplesInfo("internlm/internlm3-8b-instruct",
trust_remote_code=True),
"JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"),
"JambaForCausalLM": _HfExamplesInfo("ai21labs/AI21-Jamba-1.5-Mini"),
"JambaForCausalLM": _HfExamplesInfo("ai21labs/AI21-Jamba-1.5-Mini",
extras={"tiny": "ai21labs/Jamba-tiny-dev"}), # noqa: E501
"LlamaForCausalLM": _HfExamplesInfo("meta-llama/Meta-Llama-3-8B"),
"LLaMAForCausalLM": _HfExamplesInfo("decapoda-research/llama-7b-hf",
is_available_online=False),
@@ -167,7 +169,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True),
"PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
trust_remote_code=True),
# QWenLMHeadModel supports multimodal
"QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat",
trust_remote_code=True),
"Qwen2ForCausalLM": _HfExamplesInfo("Qwen/Qwen2-7B-Instruct"),
"Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
"RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b",
@@ -232,18 +235,19 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"),
"Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b"), # noqa: E501
"ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"), # noqa: E501
"ChatGLMModel": _HfExamplesInfo("THUDM/glm-4v-9b",
extras={"text_only": "THUDM/chatglm3-6b"},
trust_remote_code=True),
"ChatGLMForConditionalGeneration": _HfExamplesInfo("chatglm2-6b",
is_available_online=False),
"DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
"H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m"),
"GLM4VForCausalLM": _HfExamplesInfo("THUDM/glm-4v-9b",
trust_remote_code=True,
hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501
"H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m",
extras={"2b": "h2oai/h2ovl-mississippi-2b"}), # noqa: E501
"InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B",
extras={"2B": "OpenGVLab/InternVL2-2B"}, # noqa: E501
trust_remote_code=True),
"Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3"), # noqa: E501
"Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3", # noqa: E501
{"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}), # noqa: E501
"LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf",
extras={"mistral": "mistral-community/pixtral-12b"}), # noqa: E501
"LlavaNextForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-v1.6-mistral-7b-hf"), # noqa: E501
@@ -253,21 +257,24 @@ _MULTIMODAL_EXAMPLE_MODELS = {
hf_overrides={"architectures": ["MantisForConditionalGeneration"]}), # noqa: E501
"MiniCPMO": _HfExamplesInfo("openbmb/MiniCPM-o-2_6",
trust_remote_code=True),
"MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-V-2_6",
"MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5",
extras={"2.6": "openbmb/MiniCPM-V-2_6"}, # noqa: E501
trust_remote_code=True),
"MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",
extras={"olmo": "allenai/Molmo-7B-O-0924"}, # noqa: E501
trust_remote_code=True),
"NVLM_D": _HfExamplesInfo("nvidia/NVLM-D-72B",
trust_remote_code=True),
"PaliGemmaForConditionalGeneration": _HfExamplesInfo("google/paligemma-3b-pt-224"), # noqa: E501
"PaliGemmaForConditionalGeneration": _HfExamplesInfo("google/paligemma-3b-mix-224", # noqa: E501
extras={"v2": "google/paligemma2-3b-ft-docci-448"}), # noqa: E501
"Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct",
trust_remote_code=True),
"PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409", # noqa: E501
tokenizer_mode="mistral"),
"QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-VL-Chat",
extras={"text_only": "Qwen/Qwen-7B-Chat"}, # noqa: E501
trust_remote_code=True),
"QwenVLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen-VL",
extras={"chat": "Qwen/Qwen-VL-Chat"}, # noqa: E501
trust_remote_code=True,
hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]}), # noqa: E501
"Qwen2AudioForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-Audio-7B-Instruct"), # noqa: E501
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501
"Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501

@@ -18,8 +18,7 @@ def test_can_initialize(model_arch):

# Avoid OOM
def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
if hf_config.model_type == "deepseek_vl_v2":
hf_config.update({"architectures": ["DeepseekVLV2ForCausalLM"]})
hf_config.update(model_info.hf_overrides)

if hasattr(hf_config, "text_config"):
text_config: PretrainedConfig = hf_config.text_config

@@ -1,20 +1,12 @@
# SPDX-License-Identifier: Apache-2.0

# Adapted from
# https://github.com/THUDM/CogAgent
"""Inference-only CogAgent model compatible with THUDM weights."""
from argparse import Namespace
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
Union)
# https://github.com/THUDM/ChatGLM2-6B
"""Inference-only ChatGLM model compatible with THUDM weights."""
from typing import Iterable, List, Optional, Set, Tuple, Union

import torch
from torch import nn
from torch.nn import LayerNorm
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from transformers import PreTrainedTokenizer, TensorType
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
@@ -31,204 +23,14 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, BatchFeature,
MultiModalFieldConfig,
PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import ChatGLMConfig

from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix, merge_multimodal_embeddings)


class GLMImagePixelInputs(TypedDict):
pixel_values: torch.Tensor
"""Shape: `(batch_size, num_channels, height, width)`"""


class GLM4VProcessor:
"""
This model doesn't define its own HF processor,
so we implement our own one here.

"""

def __init__(
self,
config: ChatGLMConfig,
tokenizer: PreTrainedTokenizer,
) -> None:
super().__init__()

self.config = config
self.tokenizer = tokenizer

if vision_config := getattr(config, "vision_config", None):
image_size = vision_config["image_size"]

self.image_transform = transforms.Compose([
transforms.Resize(
(image_size, image_size),
interpolation=InterpolationMode.BICUBIC,
),
transforms.ToTensor(),
transforms.Normalize(
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711),
),
])
else:
self.image_transform = None

def __call__(
self,
text: Optional[Union[TextInput, list[TextInput]]] = None,
images: Optional[Union[ImageInput, list[ImageInput]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
) -> BatchFeature:
if text is None:
text = []
if not isinstance(text, list):
text = [text]
if images is None:
images = []
if not isinstance(images, list):
images = [images]
text_inputs = self.tokenizer(text)
if len(images) == 0:
image_inputs = {}
else:
if self.image_transform is None:
raise ValueError("This model does not support image inputs")

pixel_values = [self.image_transform(image) for image in images]
image_inputs = {"pixel_values": torch.stack(pixel_values)}

return BatchFeature(
{
**text_inputs,
**image_inputs,
},
tensor_type=return_tensors,
)


class GLM4VProcessingInfo(BaseProcessingInfo):

def get_tokenizer(self):
tokenizer = self.ctx.tokenizer
assert isinstance(tokenizer, PreTrainedTokenizer)
return tokenizer

def get_hf_config(self):
return self.ctx.get_hf_config(ChatGLMConfig)

def get_hf_processor(self) -> GLM4VProcessor:
return GLM4VProcessor(
self.get_hf_config(),
self.get_tokenizer(),
)

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}

def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_num_image_feature_tokens()}

def get_num_image_tokens(self) -> int:
hf_config = self.get_hf_config()
if not (vision_config := getattr(hf_config, "vision_config", None)):
return 0

image_size = vision_config["image_size"]
patch_size = vision_config["patch_size"]
grid_length = image_size // patch_size // 2
return grid_length * grid_length

def get_num_image_feature_tokens(self) -> int:
# EVA2CLIPModel has embeddings for boi and eoi tokens as well
return self.get_num_image_tokens() + 2


class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]):

def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
hf_config = self.info.get_hf_config()
if not (vision_config := getattr(hf_config, "vision_config", None)):
return ProcessorInputs(prompt_text="", mm_data={})

target_width = target_height = vision_config["image_size"]
num_images = mm_counts.get("image", 0)

mm_data = {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
}

base_text = "<|begin_of_image|><|endoftext|><|end_of_image|>"

return ProcessorInputs(
prompt_text=base_text * num_images,
mm_data=mm_data,
)


class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]):

def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(pixel_values=MultiModalFieldConfig.batched("image"))

def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self.info.get_hf_config()
if not hasattr(hf_config, "vision_config"):
return []

boi_token_id = hf_config.boi_token_id
image_token_id = hf_config.pad_token_id
eoi_token_id = hf_config.eoi_token_id

def get_replacement(item_idx: int):
num_image_tokens = self.info.get_num_image_tokens()
image_tokens = [image_token_id] * num_image_tokens

return [boi_token_id] + image_tokens + [eoi_token_id]

return [
PromptReplacement(
modality="image",
target=[boi_token_id, image_token_id, eoi_token_id],
replacement=get_replacement,
),
]
maybe_prefix)


class GLMAttention(nn.Module):
@@ -489,7 +291,7 @@ class GLMTransformer(nn.Module):
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states = layer(
@@ -498,8 +300,12 @@ class GLMTransformer(nn.Module):
kv_cache=kv_caches[i - self.start_layer],
attn_metadata=attn_metadata,
)

if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})

# Final layer norm.
if get_pp_group().is_last_rank and self.post_layer_norm:
if self.post_layer_norm:
hidden_states = self.final_layernorm(hidden_states)

return hidden_states
@@ -534,61 +340,11 @@ class ChatGLMModel(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.output_layer")

vision_config_flag = getattr(config, 'vision_config', None)
if vision_config_flag is not None:
self.vision_config = Namespace(**config.vision_config)
self.vision = EVA2CLIPModel(self.config,
quant_config,
prefix=f"{prefix}.vision")
else:
self.vision = None

self.make_empty_intermediate_tensors = (
self.encoder.make_empty_intermediate_tensors)

def _parse_and_validate_image_input(
self, **kwargs: object) -> GLMImagePixelInputs:

pixel_values = kwargs.pop("pixel_values", None)
if pixel_values is not None and self.vision is not None:
if isinstance(pixel_values, torch.Tensor):
if pixel_values.ndim > 2:
pixel_values = torch.concat(list(pixel_values))
elif isinstance(pixel_values, list):
return torch.concat(pixel_values)
else:
raise TypeError("""pixel_values must be a torch.Tensor
or a list of torch.Tensor
""")
return GLMImagePixelInputs(pixel_values=pixel_values)

def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input["pixel_values"] is None:
return None
pixel_values = image_input["pixel_values"].to(
dtype=self.config.torch_dtype)
vision_embeddings = self.vision(pixel_values)
return vision_embeddings

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.embedding(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
placeholder_token_id=[
self.config.boi_token_id,
self.config.pad_token_id,
self.config.eoi_token_id,
],
)
return inputs_embeds
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embedding(input_ids)

def forward(
self,
@@ -599,26 +355,24 @@ class ChatGLMModel(nn.Module):
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]

# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
if intermediate_tensors is not None:
inputs_embeds = intermediate_tensors["hidden_states"]
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
# Run encoder.
hidden_states = self.encoder(
hidden_states=inputs_embeds,
hidden_states=hidden_states,
position_ids=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
)

if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
return hidden_states

def load_weights(self, weights: Iterable[Tuple[str,
@@ -660,12 +414,18 @@ class ChatGLMModel(nn.Module):
return loaded_params


class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP):
class ChatGLMBaseModel(nn.Module):

hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={".word_embeddings": ""}, )

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
transformer_type: type[ChatGLMModel] = ChatGLMModel,
) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
@@ -678,27 +438,17 @@ class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP):
self.quant_config = quant_config
self.max_position_embeddings = getattr(config, "max_sequence_length",
8192)
self.transformer = ChatGLMModel(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
self.transformer = transformer_type(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
if self.config.tie_word_embeddings:
self.transformer.output_layer.weight = (
self.transformer.embedding.weight)
self.lm_head = self.transformer.output_layer
self.logits_processor = LogitsProcessor(config.padded_vocab_size)
self.sampler = get_sampler()

def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
**kwargs)
return hidden_states
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)

def compute_logits(
self,
@@ -722,7 +472,7 @@ class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP):
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)


class ChatGLM(ChatGLMBaseModel):
class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"query_key_value": ["query_key_value"],
"dense_h_to_4h": ["dense_h_to_4h"]
@@ -738,82 +488,28 @@ class ChatGLM(ChatGLMBaseModel):
embedding_modules = {}
embedding_padding_modules = []

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
if hasattr(config, "vision_config"):
hf_overrides = {"architectures": ["GLM4VForCausalLM"]}
raise RuntimeError(
"The configuration of this model indicates that it supports "
"vision inputs, but you instantiated the text-only version "
"of this model. Please use the vision model by setting "
f"`--hf-overrides {hf_overrides!r}`")

class ChatGLMV(ChatGLMBaseModel, SupportsMultiModal):
super().__init__(vllm_config=vllm_config, prefix=prefix)

packed_modules_mapping = {
"query_key_value": ["query_key_value"],
"dense_h_to_4h": ["dense_h_to_4h"],
"merged_proj": ["gate_proj", "dense_h_to_4h"]
}
# LoRA specific attributes
supported_lora_modules = [
"query_key_value",
"dense",
"dense_h_to_4h",
"dense_4h_to_h",
# vision
"fc1",
"fc2",
"merged_proj",
"linear_proj"
]

embedding_modules = {}
embedding_padding_modules = []

def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(
language_model="transformer.encoder",
connector="transformer.vision.linear_proj",
tower_model="transformer.vision.transformer")

def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
return self.transformer.get_multimodal_embeddings(**kwargs)

def get_input_embeddings(
def forward(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
return self.transformer.get_input_embeddings(input_ids,
multimodal_embeddings)


@MULTIMODAL_REGISTRY.register_processor(GLM4VMultiModalProcessor,
info=GLM4VProcessingInfo,
dummy_inputs=GLM4VDummyInputsBuilder)
class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
SupportsMultiModal):
# Ensure that the LoRA support check passes when the class is not
# initialized, but set all these attributes to empty.
# These will be updated when an instance class is selected
packed_modules_mapping = {}
supported_lora_modules = []
embedding_modules = {}
embedding_padding_modules = []

def __new__(
cls,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
config = vllm_config.model_config.hf_config

# Initialize VL
if hasattr(config, "vision_config"): # noqa: SIM108
instance_cls = ChatGLMV
# Initialize LLM
else:
instance_cls = ChatGLM

# quant_config references base class members,
# so update values before init is called
cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping)
cls.supported_lora_modules += instance_cls.supported_lora_modules
cls.embedding_modules.update(instance_cls.embedding_modules)
cls.embedding_padding_modules += instance_cls.embedding_padding_modules
return instance_cls(vllm_config=vllm_config, prefix=prefix)
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
inputs_embeds)
return hidden_states

@@ -1,312 +0,0 @@
# SPDX-License-Identifier: Apache-2.0

# Adapted from
# https://github.com/THUDM/GLM-4
"""Inference-only GLM-4v model visual encoder compatible with THUDM weights."""
from argparse import Namespace
from typing import Optional

import torch
from torch import nn
from torch.nn import LayerNorm

from vllm.attention.layer import MultiHeadAttention
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)


class PatchEmbedding(nn.Module):

def __init__(self, config):
super().__init__()
self.proj = nn.Conv2d(config.in_channels,
config.hidden_size,
kernel_size=config.patch_size,
stride=config.patch_size)
self.cls_embedding = nn.Parameter(torch.zeros(1, config.hidden_size))
self.position_embedding = nn.Embedding(config.num_positions,
config.hidden_size)

def forward(self, images: torch.Tensor) -> torch.Tensor:
"""
Parameters:
images : torch.Tensor
Input image tensor with shape (B, C, H, W)

Returns:
torch.Tensor
Transformed tensor with shape (B, L, D)
"""
images = images.to(device=self.proj.weight.device,
dtype=self.proj.weight.dtype)
x = self.proj(images)
x = x.flatten(2).transpose(1, 2)
cls_token = self.cls_embedding.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
x += self.position_embedding.weight.unsqueeze(0)
return x


class Attention(nn.Module):

def __init__(
self,
config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = '',
):
super().__init__()
self.hidden_size = config.hidden_size
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_rank = config.num_heads // self.tp_size
self.head_dim = config.hidden_size // config.num_heads
self.scale = self.head_dim**-0.5

self.query_key_value = QKVParallelLinear(
config.hidden_size,
self.head_dim,
config.num_heads,
quant_config=quant_config,
prefix=f"{prefix}.query_key_value",
)
self.dense = RowParallelLinear(
config.hidden_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.dense",
)

self.attn = MultiHeadAttention(self.num_heads_per_rank, self.head_dim,
self.scale)
self.output_dropout = torch.nn.Dropout(config.dropout_prob)

def forward(self, x: torch.Tensor) -> torch.Tensor:
qkv, _ = self.query_key_value(x)  # B, L, 3 * H * D
q, k, v = qkv.chunk(3, dim=-1)

out = self.attn(q, k, v)
output, _ = self.dense(out)
output = self.output_dropout(output)
return output


class MLP(nn.Module):

def __init__(
self,
config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = '',
):
super().__init__()
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
self.fc1 = ColumnParallelLinear(
config.hidden_size,
config.intermediate_size,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
)
self.fc2 = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.fc2",
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x, _ = self.fc1(x)
x = self.activation_fn(x)
x, _ = self.fc2(x)
return x


class TransformerLayer(nn.Module):

def __init__(
self,
config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = '',
):
super().__init__()
self.input_layernorm = LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.attention = Attention(config,
quant_config=quant_config,
prefix=f"{prefix}.attention")
self.mlp = MLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.post_attention_layernorm = LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)

def forward(self, hidden_states):
attention_input = hidden_states
attention_output = self.input_layernorm(
self.attention(attention_input))
hidden_states = attention_input + attention_output
mlp_input = hidden_states
mlp_output = self.post_attention_layernorm(self.mlp(mlp_input))
output = mlp_input + mlp_output
return output


class Transformer(nn.Module):

def __init__(
self,
config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = '',
):
super().__init__()
self.layers = nn.ModuleList([
TransformerLayer(config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(config.num_hidden_layers)
|
||||
])
|
||||
|
||||
def forward(self, hidden_states):
|
||||
for layer_module in self.layers:
|
||||
hidden_states = layer_module(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class GLU(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
in_features,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = '',
|
||||
):
|
||||
"""
|
||||
The original implementation is the same as:
|
||||
```python
|
||||
self.dense_h_to_4h = ColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
config.ffn_hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config
|
||||
)
|
||||
|
||||
self.gate_proj = ColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
config.ffn_hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config
|
||||
)
|
||||
```
|
||||
```
|
||||
gate_proj_output, _ = self.gate_proj(x)
|
||||
dense_h_to_4h_output, _ = self.dense_h_to_4h(x)
|
||||
x = torch.cat([gate_proj_output, dense_h_to_4h_output], dim=-1)
|
||||
```
|
||||
|
||||
We merge two ColumnParallelLinear into one MergedColumnParallelLinear:
|
||||
```
|
||||
self.merged_proj = MergedColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
[config.ffn_hidden_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config
|
||||
)
|
||||
```
|
||||
```
|
||||
x, _ = self.merged_proj(x)
|
||||
```
|
||||
"""
|
||||
super().__init__()
|
||||
self.linear_proj = ReplicatedLinear(in_features,
|
||||
config.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.linear_proj")
|
||||
self.norm1 = nn.LayerNorm(config.hidden_size)
|
||||
self.act1 = nn.GELU()
|
||||
self.act2 = SiluAndMul()
|
||||
|
||||
self.merged_proj = MergedColumnParallelLinear(
|
||||
config.hidden_size, [config.ffn_hidden_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.merged_proj")
|
||||
|
||||
self.dense_4h_to_h = RowParallelLinear(
|
||||
config.ffn_hidden_size,
|
||||
config.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.dense_4h_to_h")
|
||||
|
||||
def forward(self, x):
|
||||
x, _ = self.linear_proj(x)
|
||||
x = self.act1(self.norm1(x))
|
||||
x, _ = self.merged_proj(x)
|
||||
x = self.act2(x)
|
||||
x, _ = self.dense_4h_to_h(x)
|
||||
return x
|
||||
|
||||
|
||||
class EVA2CLIPModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = '',
|
||||
):
|
||||
super().__init__()
|
||||
vision_config = Namespace(**config.vision_config)
|
||||
self.patch_embedding = PatchEmbedding(vision_config)
|
||||
self.transformer = Transformer(vision_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.transformer")
|
||||
self.linear_proj = GLU(config,
|
||||
in_features=config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.linear_proj")
|
||||
self.conv = nn.Conv2d(in_channels=vision_config.hidden_size,
|
||||
out_channels=config.hidden_size,
|
||||
kernel_size=2,
|
||||
stride=2)
|
||||
self.boi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
||||
self.eoi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
||||
self.scaling_factor = vision_config.scaling_factor
|
||||
|
||||
def forward(self, images: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Parameters:
|
||||
images : torch.Tensor
|
||||
Input image tensor with shape (B, C, H, W)
|
||||
|
||||
Returns:
|
||||
torch.Tensor
|
||||
Transformed tensor with shape (B, L, D)
|
||||
"""
|
||||
x = self.patch_embedding(images)
|
||||
x = self.transformer(x)
|
||||
x = x[:, 1:]
|
||||
|
||||
b, s, h = x.shape
|
||||
grid_size = int(s**0.5)
|
||||
x = x.view(b, grid_size, grid_size, h).permute(0, 3, 1, 2)
|
||||
x = self.conv(x)
|
||||
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
x = self.linear_proj(x)
|
||||
boi = self.boi.expand(x.shape[0], -1, -1)
|
||||
eoi = self.eoi.expand(x.shape[0], -1, -1)
|
||||
x = torch.cat((boi, x, eoi), dim=1)
|
||||
x = x / self.scaling_factor
|
||||
return x
|
||||
vllm/model_executor/models/glm4v.py (new file, 662 lines)
@@ -0,0 +1,662 @@
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/THUDM/CogAgent
|
||||
"""Inference-only CogAgent model compatible with THUDM weights."""
|
||||
from argparse import Namespace
|
||||
from typing import List, Literal, Mapping, Optional, TypedDict, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import LayerNorm
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms import InterpolationMode
|
||||
from transformers import PreTrainedTokenizer, TensorType
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers.tokenization_utils_base import TextInput
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
|
||||
from vllm.multimodal.parse import MultiModalDataItems
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, BatchFeature,
|
||||
MultiModalFieldConfig,
|
||||
PromptReplacement)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs import ChatGLMConfig
|
||||
|
||||
from .chatglm import ChatGLMBaseModel, ChatGLMModel
|
||||
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
|
||||
from .utils import flatten_bn, merge_multimodal_embeddings
|
||||
|
||||
|
||||
class GLMVImagePixelInputs(TypedDict):
|
||||
type: Literal["pixel_values"]
|
||||
data: torch.Tensor
|
||||
"""Shape: `(batch_size, num_channels, height, width)`"""
|
||||
|
||||
|
||||
class EVA2CLIPPatchEmbedding(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.proj = nn.Conv2d(config.in_channels,
|
||||
config.hidden_size,
|
||||
kernel_size=config.patch_size,
|
||||
stride=config.patch_size)
|
||||
self.cls_embedding = nn.Parameter(torch.zeros(1, config.hidden_size))
|
||||
self.position_embedding = nn.Embedding(config.num_positions,
|
||||
config.hidden_size)
|
||||
|
||||
def forward(self, images: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Parameters:
|
||||
images : torch.Tensor
|
||||
Input image tensor with shape (B, C, H, W)
|
||||
|
||||
Returns:
|
||||
torch.Tensor
|
||||
Transformed tensor with shape (B, L, D)
|
||||
"""
|
||||
images = images.to(device=self.proj.weight.device,
|
||||
dtype=self.proj.weight.dtype)
|
||||
x = self.proj(images)
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
cls_token = self.cls_embedding.expand(x.shape[0], -1, -1)
|
||||
x = torch.cat((cls_token, x), dim=1)
|
||||
x += self.position_embedding.weight.unsqueeze(0)
|
||||
return x
|
||||
|
||||
|
||||
class EVA2CLIPAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = '',
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.num_heads_per_rank = config.num_heads // self.tp_size
|
||||
self.head_dim = config.hidden_size // config.num_heads
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
self.query_key_value = QKVParallelLinear(
|
||||
config.hidden_size,
|
||||
self.head_dim,
|
||||
config.num_heads,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.query_key_value",
|
||||
)
|
||||
self.dense = RowParallelLinear(
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.dense",
|
||||
)
|
||||
|
||||
self.attn = MultiHeadAttention(self.num_heads_per_rank, self.head_dim,
|
||||
self.scale)
|
||||
self.output_dropout = torch.nn.Dropout(config.dropout_prob)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
qkv, _ = self.query_key_value(x) # B, L, 3 * H * D
|
||||
q, k, v = qkv.chunk(3, dim=-1)
|
||||
|
||||
out = self.attn(q, k, v)
|
||||
output, _ = self.dense(out)
|
||||
output = self.output_dropout(output)
|
||||
return output
|
||||
|
||||
|
||||
class EVA2CLIPMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = '',
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.activation_fn = get_act_fn(config.hidden_act)
|
||||
self.fc1 = ColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
config.intermediate_size,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fc1",
|
||||
)
|
||||
self.fc2 = RowParallelLinear(
|
||||
config.intermediate_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fc2",
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x, _ = self.fc1(x)
|
||||
x = self.activation_fn(x)
|
||||
x, _ = self.fc2(x)
|
||||
return x
|
||||
|
||||
|
||||
class EVA2CLIPTransformerLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = '',
|
||||
):
|
||||
super().__init__()
|
||||
self.input_layernorm = LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.attention = EVA2CLIPAttention(config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attention")
|
||||
self.mlp = EVA2CLIPMLP(config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
self.post_attention_layernorm = LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
attention_input = hidden_states
|
||||
attention_output = self.input_layernorm(
|
||||
self.attention(attention_input))
|
||||
hidden_states = attention_input + attention_output
|
||||
mlp_input = hidden_states
|
||||
mlp_output = self.post_attention_layernorm(self.mlp(mlp_input))
|
||||
output = mlp_input + mlp_output
|
||||
return output
|
||||
|
||||
|
||||
class EVA2CLIPTransformer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = '',
|
||||
):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([
|
||||
EVA2CLIPTransformerLayer(config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.layers.{layer_idx}")
|
||||
for layer_idx in range(config.num_hidden_layers)
|
||||
])
|
||||
|
||||
def forward(self, hidden_states):
|
||||
for layer_module in self.layers:
|
||||
hidden_states = layer_module(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class EVA2CLIPGLU(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
in_features,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = '',
|
||||
):
|
||||
"""
|
||||
The original implementation is the same as:
|
||||
```python
|
||||
self.dense_h_to_4h = ColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
config.ffn_hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config
|
||||
)
|
||||
|
||||
self.gate_proj = ColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
config.ffn_hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config
|
||||
)
|
||||
```
|
||||
```
|
||||
gate_proj_output, _ = self.gate_proj(x)
|
||||
dense_h_to_4h_output, _ = self.dense_h_to_4h(x)
|
||||
x = torch.cat([gate_proj_output, dense_h_to_4h_output], dim=-1)
|
||||
```
|
||||
|
||||
We merge two ColumnParallelLinear into one MergedColumnParallelLinear:
|
||||
```
|
||||
self.merged_proj = MergedColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
[config.ffn_hidden_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config
|
||||
)
|
||||
```
|
||||
```
|
||||
x, _ = self.merged_proj(x)
|
||||
```
|
||||
"""
|
||||
super().__init__()
|
||||
self.linear_proj = ReplicatedLinear(in_features,
|
||||
config.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.linear_proj")
|
||||
self.norm1 = nn.LayerNorm(config.hidden_size)
|
||||
self.act1 = nn.GELU()
|
||||
self.act2 = SiluAndMul()
|
||||
|
||||
self.merged_proj = MergedColumnParallelLinear(
|
||||
config.hidden_size, [config.ffn_hidden_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.merged_proj")
|
||||
|
||||
self.dense_4h_to_h = RowParallelLinear(
|
||||
config.ffn_hidden_size,
|
||||
config.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.dense_4h_to_h")
|
||||
|
||||
def forward(self, x):
|
||||
x, _ = self.linear_proj(x)
|
||||
x = self.act1(self.norm1(x))
|
||||
x, _ = self.merged_proj(x)
|
||||
x = self.act2(x)
|
||||
x, _ = self.dense_4h_to_h(x)
|
||||
return x
|
||||
|
||||
|
||||
class EVA2CLIPModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = '',
|
||||
):
|
||||
super().__init__()
|
||||
vision_config = Namespace(**config.vision_config)
|
||||
self.patch_embedding = EVA2CLIPPatchEmbedding(vision_config)
|
||||
self.transformer = EVA2CLIPTransformer(vision_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.transformer")
|
||||
self.linear_proj = EVA2CLIPGLU(config,
|
||||
in_features=config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.linear_proj")
|
||||
self.conv = nn.Conv2d(in_channels=vision_config.hidden_size,
|
||||
out_channels=config.hidden_size,
|
||||
kernel_size=2,
|
||||
stride=2)
|
||||
self.boi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
||||
self.eoi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
||||
self.scaling_factor = vision_config.scaling_factor
|
||||
|
||||
def forward(self, images: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Parameters:
|
||||
images : torch.Tensor
|
||||
Input image tensor with shape (B, C, H, W)
|
||||
|
||||
Returns:
|
||||
torch.Tensor
|
||||
Transformed tensor with shape (B, L, D)
|
||||
"""
|
||||
x = self.patch_embedding(images)
|
||||
x = self.transformer(x)
|
||||
x = x[:, 1:]
|
||||
|
||||
b, s, h = x.shape
|
||||
grid_size = int(s**0.5)
|
||||
x = x.view(b, grid_size, grid_size, h).permute(0, 3, 1, 2)
|
||||
x = self.conv(x)
|
||||
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
x = self.linear_proj(x)
|
||||
boi = self.boi.expand(x.shape[0], -1, -1)
|
||||
eoi = self.eoi.expand(x.shape[0], -1, -1)
|
||||
x = torch.cat((boi, x, eoi), dim=1)
|
||||
x = x / self.scaling_factor
|
||||
return x


class GLM4VModel(ChatGLMModel):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        quant_config = vllm_config.quant_config

        self.vision = EVA2CLIPModel(self.config,
                                    quant_config,
                                    prefix=f"{prefix}.vision")


class GLM4VProcessor:
    """
    This model doesn't define its own HF processor,
    so we implement our own one here.
    """

    def __init__(
        self,
        config: ChatGLMConfig,
        tokenizer: PreTrainedTokenizer,
    ) -> None:
        super().__init__()

        self.config = config
        self.tokenizer = tokenizer

        vision_config = config.vision_config
        image_size = vision_config["image_size"]

        self.image_transform = transforms.Compose([
            transforms.Resize(
                (image_size, image_size),
                interpolation=InterpolationMode.BICUBIC,
            ),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=(0.48145466, 0.4578275, 0.40821073),
                std=(0.26862954, 0.26130258, 0.27577711),
            ),
        ])

    def __call__(
        self,
        text: Optional[Union[TextInput, list[TextInput]]] = None,
        images: Optional[Union[ImageInput, list[ImageInput]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> BatchFeature:
        if text is None:
            text = []
        if not isinstance(text, list):
            text = [text]
        if images is None:
            images = []
        if not isinstance(images, list):
            images = [images]

        text_inputs = self.tokenizer(text)

        if len(images) == 0:
            image_inputs = {}
        else:
            pixel_values = [self.image_transform(image) for image in images]
            image_inputs = {"pixel_values": torch.stack(pixel_values)}

        return BatchFeature(
            {
                **text_inputs,
                **image_inputs,
            },
            tensor_type=return_tensors,
        )


class GLM4VProcessingInfo(BaseProcessingInfo):

    def get_tokenizer(self):
        tokenizer = self.ctx.tokenizer
        assert isinstance(tokenizer, PreTrainedTokenizer)
        return tokenizer

    def get_hf_config(self):
        return self.ctx.get_hf_config(ChatGLMConfig)

    def get_hf_processor(self) -> GLM4VProcessor:
        return GLM4VProcessor(
            self.get_hf_config(),
            self.get_tokenizer(),
        )

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": 1}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        return {"image": self.get_num_image_feature_tokens()}

    def get_num_image_tokens(self) -> int:
        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config

        image_size = vision_config["image_size"]
        patch_size = vision_config["patch_size"]
        grid_length = image_size // patch_size // 2
        return grid_length * grid_length

    def get_num_image_feature_tokens(self) -> int:
        # EVA2CLIPModel has embeddings for boi and eoi tokens as well
        return self.get_num_image_tokens() + 2


class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]):

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        hf_config = self.info.get_hf_config()
        vision_config = hf_config.vision_config

        target_width = target_height = vision_config["image_size"]
        num_images = mm_counts.get("image", 0)

        mm_data = {
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images)
        }

        base_text = "<|begin_of_image|><|endoftext|><|end_of_image|>"

        return ProcessorInputs(
            prompt_text=base_text * num_images,
            mm_data=mm_data,
        )


class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]):

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(pixel_values=MultiModalFieldConfig.batched("image"))

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        hf_config = self.info.get_hf_config()

        boi_token_id = hf_config.boi_token_id
        image_token_id = hf_config.pad_token_id
        eoi_token_id = hf_config.eoi_token_id

        def get_replacement(item_idx: int):
            num_image_tokens = self.info.get_num_image_tokens()
            image_tokens = [image_token_id] * num_image_tokens

            return [boi_token_id] + image_tokens + [eoi_token_id]

        return [
            PromptReplacement(
                modality="image",
                target=[boi_token_id, image_token_id, eoi_token_id],
                replacement=get_replacement,
            ),
        ]


@MULTIMODAL_REGISTRY.register_processor(GLM4VMultiModalProcessor,
                                        info=GLM4VProcessingInfo,
                                        dummy_inputs=GLM4VDummyInputsBuilder)
class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
                       SupportsMultiModal):

    packed_modules_mapping = {
        "query_key_value": ["query_key_value"],
        "dense_h_to_4h": ["dense_h_to_4h"],
        "merged_proj": ["gate_proj", "dense_h_to_4h"]
    }
    # LoRA specific attributes
    supported_lora_modules = [
        "query_key_value",
        "dense",
        "dense_h_to_4h",
        "dense_4h_to_h",
        # vision
        "fc1",
        "fc2",
        "merged_proj",
        "linear_proj"
    ]

    embedding_modules = {}
    embedding_padding_modules = []

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="transformer.encoder",
            connector="transformer.vision.linear_proj",
            tower_model="transformer.vision.transformer")

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        transformer_type: type[GLM4VModel] = GLM4VModel,
    ) -> None:
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            transformer_type=transformer_type,
        )

        self.transformer: GLM4VModel

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        h = w = self.config.vision_config["image_size"]
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[GLMVImagePixelInputs]:
        pixel_values = kwargs.pop("pixel_values", None)

        if pixel_values is not None:
            if not isinstance(pixel_values, torch.Tensor):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            return GLMVImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(
                    flatten_bn(pixel_values, concat=True)),
            )

        return None

    def _process_image_input(
            self, image_input: GLMVImagePixelInputs) -> torch.Tensor:
        pixel_values = image_input["data"].to(dtype=self.config.torch_dtype)

        return self.transformer.vision(pixel_values)

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None

        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        inputs_embeds = self.transformer.get_input_embeddings(input_ids)

        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                multimodal_embeddings=multimodal_embeddings,
                placeholder_token_id=[
                    self.config.boi_token_id,
                    self.config.pad_token_id,
                    self.config.eoi_token_id,
                ],
            )

        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors,
                                         inputs_embeds)

        return hidden_states

@@ -6,381 +6,35 @@
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
"""Inference-only QWen model compatible with HuggingFace weights."""

import copy
|
||||
import math
|
||||
import re
|
||||
import unicodedata
|
||||
from functools import lru_cache, partial
|
||||
from typing import (AbstractSet, Any, Callable, Collection, Dict, Iterable,
|
||||
List, Literal, Mapping, Optional, Set, Tuple, TypedDict,
|
||||
Union)
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms import InterpolationMode
|
||||
from transformers import (BatchFeature, PretrainedConfig, PreTrainedTokenizer,
|
||||
TensorType)
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers.tokenization_utils_base import TextInput
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
|
||||
NestedTensors)
|
||||
from vllm.multimodal.parse import MultiModalDataItems
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptReplacement,
|
||||
PromptReplacementDetails)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
|
||||
from .utils import (flatten_bn, is_pp_missing_parameter,
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class QwenImagePixelInputs(TypedDict):
|
||||
type: Literal["pixel_values"]
|
||||
data: torch.Tensor
|
||||
"""
|
||||
Shape: `(batch_size * num_images, 3, image_size, image_size)`
|
||||
|
||||
Note that image_size is the value in the vision config to which we resize
|
||||
the image to in the normalization transform. Currently multi-image support
|
||||
can only be leveraged by passing image embeddings directly.
|
||||
"""
|
||||
|
||||
|
||||
class QwenImageEmbeddingInputs(TypedDict):
|
||||
type: Literal["image_embeds"]
|
||||
data: torch.Tensor
|
||||
"""Shape: `(batch_size * num_images, 256, hidden_size)`
|
||||
|
||||
`hidden_size` must match the hidden size of the language model backbone
|
||||
and is stored in the visual config of the model if we have one.
|
||||
"""
|
||||
|
||||
|
||||
QwenImageInputs = Union[QwenImagePixelInputs, QwenImageEmbeddingInputs]
|
||||
|
||||
|
||||
class VisualAttention(nn.Module):
|
||||
"""self-attention layer class.
|
||||
Self-attention layer takes input with size [s, b, h]
|
||||
and returns output of the same size.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
bias: bool = True,
|
||||
kdim: Optional[int] = None,
|
||||
vdim: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.kdim = kdim if kdim is not None else embed_dim
|
||||
self.vdim = vdim if vdim is not None else embed_dim
|
||||
self._qkv_same_embed_dim = self.kdim == embed_dim \
|
||||
and self.vdim == embed_dim
|
||||
|
||||
self.num_heads = num_heads
|
||||
|
||||
# Per attention head and per partition values.
|
||||
assert embed_dim % num_heads == 0
|
||||
self.hidden_size_per_attention_head = embed_dim // num_heads
|
||||
self.num_attention_heads_per_partition = num_heads
|
||||
self.hidden_size_per_partition = embed_dim
|
||||
|
||||
# Strided linear layer.
|
||||
assert self._qkv_same_embed_dim, \
|
||||
'Visual Attention implementation only supports self-attention'
|
||||
self.in_proj = ReplicatedLinear(embed_dim, 3 * embed_dim)
|
||||
self.out_proj = ReplicatedLinear(embed_dim, embed_dim)
|
||||
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
# query/key/value: [sq, b, h]
|
||||
sq, b, _ = x.size()
|
||||
mixed_x_layer, _ = self.in_proj(x)
|
||||
|
||||
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
|
||||
new_tensor_shape = mixed_x_layer.size()[:-1] + \
|
||||
(self.num_attention_heads_per_partition,
|
||||
3 * self.hidden_size_per_attention_head)
|
||||
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
|
||||
|
||||
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
|
||||
query_layer, key_layer, value_layer = mixed_x_layer.split(
|
||||
self.hidden_size_per_attention_head, dim=-1)
|
||||
|
||||
# [sq, b, np, hn] -> [sq, b * np, hn]
|
||||
query_layer = query_layer.view(
|
||||
sq, b * self.num_attention_heads_per_partition,
|
||||
self.hidden_size_per_attention_head).transpose(0, 1)
|
||||
# [sk, b, np, hn] -> [sk, b * np, hn]
|
||||
key_layer = key_layer.view(
|
||||
sq, b * self.num_attention_heads_per_partition,
|
||||
self.hidden_size_per_attention_head).transpose(0, 1)
|
||||
|
||||
q_scaled = query_layer / self.norm_factor
|
||||
if attn_mask is not None:
|
||||
attention_probs = torch.baddbmm(attn_mask, q_scaled,
|
||||
key_layer.transpose(-2, -1))
|
||||
else:
|
||||
attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1))
|
||||
attention_probs = attention_probs.softmax(dim=-1)
|
||||
|
||||
value_layer = value_layer.view(
|
||||
sq, b * self.num_attention_heads_per_partition,
|
||||
self.hidden_size_per_attention_head).transpose(0, 1)
|
||||
|
||||
# matmul: [b * np, sq, hn]
|
||||
context_layer = torch.bmm(attention_probs, value_layer)
|
||||
|
||||
# change view [b, np, sq, hn]
|
||||
context_layer = context_layer.view(
|
||||
b, self.num_attention_heads_per_partition, sq,
|
||||
self.hidden_size_per_attention_head)
|
||||
|
||||
# [b, np, sq, hn] --> [sq, b, np, hn]
|
||||
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
|
||||
|
||||
# [sq, b, np, hn] --> [sq, b, hp]
|
||||
new_context_layer_shape = context_layer.size()[:-2] + \
|
||||
(self.hidden_size_per_partition,)
|
||||
context_layer = context_layer.view(*new_context_layer_shape)
|
||||
|
||||
output, _ = self.out_proj(context_layer)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class QwenVMLP(nn.Module):
|
||||
"""MLP for the visual component of the Qwen model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.c_fc = ColumnParallelLinear(hidden_size,
|
||||
intermediate_size,
|
||||
bias=True,
|
||||
quant_config=quant_config)
|
||||
self.act_fn = get_act_fn("gelu")
|
||||
self.c_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x, _ = self.c_fc(x)
|
||||
x = self.act_fn(x)
|
||||
x, _ = self.c_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class VisualAttentionBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
n_head: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.ln_1 = norm_layer(d_model)
|
||||
self.ln_2 = norm_layer(d_model)
|
||||
mlp_width = int(d_model * mlp_ratio)
|
||||
self.attn = VisualAttention(d_model, n_head)
|
||||
self.mlp = QwenVMLP(
|
||||
hidden_size=d_model,
|
||||
intermediate_size=mlp_width,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def attention(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None
|
||||
return self.attn(x, attn_mask=attn_mask)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
x = x + self.attention(self.ln_1(x), attn_mask=attn_mask)
|
||||
x = x + self.mlp(self.ln_2(x))
|
||||
return x
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
width: int,
|
||||
layers: int,
|
||||
heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.layers = layers
|
||||
|
||||
self.resblocks = nn.ModuleList([
|
||||
VisualAttentionBlock(width,
|
||||
heads,
|
||||
mlp_ratio,
|
||||
norm_layer=norm_layer,
|
||||
quant_config=quant_config)
|
||||
for _ in range(layers)
|
||||
])
|
||||
|
||||
def get_cast_dtype(self) -> torch.dtype:
|
||||
return self.resblocks[0].mlp.c_fc.weight.dtype
|
||||
|
||||
def get_cast_device(self) -> torch.device:
|
||||
return self.resblocks[0].mlp.c_fc.weight.device
|
||||
|
||||
def forward(self,
|
||||
x: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
for r in self.resblocks:
|
||||
x = r(x, attn_mask=attn_mask)
|
||||
return x
|
||||
|
||||
|
||||
class VisionTransformer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
image_size: int,
|
||||
patch_size: int,
|
||||
width: int,
|
||||
layers: int,
|
||||
heads: int,
|
||||
mlp_ratio: float,
|
||||
n_queries: int = 256,
|
||||
output_dim: int = 512,
|
||||
image_start_id: int = 151857,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
image_height, image_width = self.image_size = (image_size, image_size)
|
||||
patch_height, patch_width = self.patch_size = (patch_size, patch_size)
|
||||
self.grid_size = (image_height // patch_height,
|
||||
image_width // patch_width)
|
||||
self.output_dim = output_dim
|
||||
self.conv1 = nn.Conv2d(in_channels=3,
|
||||
out_channels=width,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
bias=False)
|
||||
|
||||
# class embeddings and positional embeddings
|
||||
scale = width**-0.5
|
||||
self.positional_embedding = nn.Parameter(scale *
|
||||
torch.randn(256, width))
|
||||
|
||||
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
||||
|
||||
self.ln_pre = norm_layer(width)
|
||||
self.transformer = TransformerBlock(width,
|
||||
layers,
|
||||
heads,
|
||||
mlp_ratio,
|
||||
norm_layer=norm_layer,
|
||||
quant_config=quant_config)
|
||||
|
||||
self.attn_pool = Resampler2(
|
||||
grid_size=int(math.sqrt(n_queries)),
|
||||
embed_dim=output_dim,
|
||||
num_heads=output_dim // 128,
|
||||
kv_dim=width,
|
||||
norm_layer=norm_layer,
|
||||
adaptive=False,
|
||||
do_post_projection=False,
|
||||
).to(
|
||||
device=self.positional_embedding.device,
|
||||
dtype=self.positional_embedding.dtype,
|
||||
)
|
||||
|
||||
self.ln_post = norm_layer(output_dim)
|
||||
self.proj = nn.Parameter(
|
||||
(output_dim**-0.5) * torch.randn(output_dim, output_dim))
|
||||
|
||||
self.image_start_id = image_start_id
|
||||
self.image_end_id = image_start_id + 1
|
||||
self.image_pad_id = image_start_id + 2
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x.to(
|
||||
dtype=self.transformer.get_cast_dtype(),
|
||||
device=self.transformer.get_cast_device(),
|
||||
)
|
||||
|
||||
# to patches
|
||||
x = self.conv1(x) # shape = [*, width, grid, grid]
|
||||
x = x.reshape(x.shape[0], x.shape[1],
|
||||
-1) # shape = [*, width, grid ** 2]
|
||||
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
||||
|
||||
x = x + get_abs_pos(self.positional_embedding, int(math.sqrt(
|
||||
x.size(1))))
|
||||
|
||||
x = self.ln_pre(x)
|
||||
|
||||
x = x.permute(1, 0, 2) # NLD -> LND
|
||||
x = self.transformer(x)
|
||||
x = x.permute(1, 0, 2) # LND -> NLD
|
||||
|
||||
x = self.attn_pool(x)
|
||||
x = self.ln_post(x)
|
||||
x = x @ self.proj
|
||||
|
||||
return x
|
||||
maybe_prefix)
|
||||
|
||||
|
||||
class QWenMLP(nn.Module):
@@ -564,12 +218,6 @@ class QWenModel(nn.Module):
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

        if (vision_config := getattr(config, "visual", None)):
            self.visual = VisionTransformer(**vision_config,
                                            quant_config=quant_config)
        else:
            self.visual = None

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.wte(input_ids)

@@ -592,6 +240,7 @@ class QWenModel(nn.Module):
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.h[i]
            hidden_states, residual = layer(
@@ -610,302 +259,25 @@ class QWenModel(nn.Module):
        return hidden_states
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _get_tokenizer_without_image_pad(
|
||||
tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer:
|
||||
"""
|
||||
The logic of adding image pad tokens should only be applied in
|
||||
:class:`QWenVLProcessor`, so they are patched out here.
|
||||
|
||||
The definition of the wrapped tokenizer can be found here:
|
||||
https://huggingface.co/Qwen/Qwen-VL/blob/main/tokenization_qwen.py
|
||||
"""
|
||||
new_tokenizer = copy.deepcopy(tokenizer)
|
||||
|
||||
class TokenizerWithoutImagePad(tokenizer.__class__): # type: ignore
|
||||
|
||||
def tokenize(
|
||||
self,
|
||||
text: str,
|
||||
allowed_special: Union[AbstractSet[str], str] = "all",
|
||||
disallowed_special: Union[Collection[str], str] = (),
|
||||
**kwargs,
|
||||
) -> list[Union[bytes, str]]:
|
||||
text = unicodedata.normalize("NFC", text)
|
||||
|
||||
return [
|
||||
self.decoder[t] for t in self.tokenizer.encode(
|
||||
text,
|
||||
allowed_special=allowed_special,
|
||||
disallowed_special=disallowed_special,
|
||||
)
|
||||
]
|
||||
|
||||
def _decode(
|
||||
self,
|
||||
token_ids: Union[int, List[int]],
|
||||
skip_special_tokens: bool = False,
|
||||
errors: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
if isinstance(token_ids, int):
|
||||
token_ids = [token_ids]
|
||||
|
||||
return self.tokenizer.decode(
|
||||
token_ids,
|
||||
errors=errors or self.errors,
|
||||
)
|
||||
|
||||
TokenizerWithoutImagePad.__name__ = \
|
||||
f"{tokenizer.__class__.__name__}WithoutImagePad"
|
||||
|
||||
new_tokenizer.__class__ = TokenizerWithoutImagePad
|
||||
return new_tokenizer
|
||||
|
||||
|
||||
class QWenVLProcessor:
|
||||
"""
|
||||
This model doesn't define its own HF processor,
|
||||
so we implement our own one here.
|
||||
|
||||
We call the wrapped tokenizer to automatically insert image pad tokens:
|
||||
https://huggingface.co/Qwen/Qwen-VL/blob/main/tokenization_qwen.py#L245
|
||||
|
||||
The image processor is defined here:
|
||||
https://huggingface.co/Qwen/Qwen-VL/blob/main/visual.py#L354
|
||||
"""
|
||||
class QWenBaseModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
transformer_type: type[QWenModel] = QWenModel,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
if vision_config := getattr(self.config, "visual", None):
|
||||
image_size = vision_config["image_size"]
|
||||
|
||||
self.image_transform = transforms.Compose([
|
||||
transforms.Resize(
|
||||
(image_size, image_size),
|
||||
interpolation=InterpolationMode.BICUBIC,
|
||||
),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=(0.48145466, 0.4578275, 0.40821073),
|
||||
std=(0.26862954, 0.26130258, 0.27577711),
|
||||
),
|
||||
])
|
||||
else:
|
||||
self.image_transform = None
|
||||
|
||||
@property
|
||||
def image_start_tag(self) -> str:
|
||||
return self.tokenizer.image_start_tag # type: ignore
|
||||
|
||||
@property
|
||||
def image_end_tag(self) -> str:
|
||||
return self.tokenizer.image_end_tag # type: ignore
|
||||
|
||||
@property
|
||||
def image_pad_tag(self) -> str:
|
||||
return self.tokenizer.image_pad_tag # type: ignore
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: Optional[Union[TextInput, list[TextInput]]] = None,
|
||||
images: Optional[Union[ImageInput, list[ImageInput]]] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
) -> BatchFeature:
|
||||
if text is None:
|
||||
text = []
|
||||
if not isinstance(text, list):
|
||||
text = [text]
|
||||
if images is None:
|
||||
images = []
|
||||
if not isinstance(images, list):
|
||||
images = [images]
|
||||
|
||||
text_inputs = self.tokenizer(text)
|
||||
|
||||
if len(images) == 0:
|
||||
image_inputs = {}
|
||||
else:
|
||||
if self.image_transform is None:
|
||||
raise ValueError("This model does not support image inputs")
|
||||
|
||||
pixel_values = [self.image_transform(image) for image in images]
|
||||
image_inputs = {"pixel_values": torch.stack(pixel_values)}
|
||||
|
||||
return BatchFeature(
|
||||
{
|
||||
**text_inputs,
|
||||
**image_inputs,
|
||||
},
|
||||
tensor_type=return_tensors,
|
||||
)
|
||||
|
||||
|
||||
class QWenVLProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
def get_tokenizer(self) -> PreTrainedTokenizer:
|
||||
tokenizer = self.ctx.tokenizer
|
||||
assert isinstance(tokenizer, PreTrainedTokenizer)
|
||||
|
||||
return _get_tokenizer_without_image_pad(tokenizer)
|
||||
|
||||
def get_hf_processor(self) -> QWenVLProcessor:
|
||||
tokenizer = self.ctx.tokenizer
|
||||
assert isinstance(tokenizer, PreTrainedTokenizer)
|
||||
|
||||
return QWenVLProcessor(self.get_hf_config(), tokenizer)
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"image": None}
|
||||
|
||||
def get_mm_max_tokens_per_item(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> Mapping[str, int]:
|
||||
return {"image": self.get_num_image_tokens()}
|
||||
|
||||
def get_num_image_tokens(self) -> int:
|
||||
hf_config = self.get_hf_config()
|
||||
if not (vision_config := getattr(hf_config, "visual", None)):
|
||||
return 0
|
||||
|
||||
image_size = vision_config["image_size"]
|
||||
patch_size = vision_config["patch_size"]
|
||||
grid_length = image_size // patch_size // 2
|
||||
return grid_length * grid_length
|
||||
|
||||
|
||||
class QWenVLDummyInputsBuilder(BaseDummyInputsBuilder[QWenVLProcessingInfo]):
|
||||
|
||||
def get_dummy_processor_inputs(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
hf_config = self.info.get_hf_config()
|
||||
if not (vision_config := getattr(hf_config, "visual", None)):
|
||||
return ProcessorInputs(prompt_text="", mm_data={})
|
||||
|
||||
processor = self.info.get_hf_processor()
|
||||
img_start = processor.image_start_tag
|
||||
img_end = processor.image_end_tag
|
||||
|
||||
target_width = target_height = vision_config["image_size"]
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
mm_data = {
|
||||
"image":
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images)
|
||||
}
|
||||
|
||||
return ProcessorInputs(
|
||||
prompt_text="".join(f"Picture {i}: {img_start}{img_end}\n"
|
||||
for i in range(1, num_images + 1)),
|
||||
mm_data=mm_data,
|
||||
)
|
||||
|
||||
|
||||
class QWenVLMultiModalProcessor(BaseMultiModalProcessor[QWenVLProcessingInfo]):
|
||||
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
# Drops anything between <img>/</img> tags; encoding with the tokenizer
|
||||
# will automatically add the image pads for the context.
|
||||
prompt, num_matched_images = re.subn(
|
||||
r"(Picture \d*: <img>).*?(<\/img>\n)",
|
||||
r"\1\2",
|
||||
prompt,
|
||||
)
|
||||
|
||||
image_data = mm_data.get("images")
|
||||
if image_data is not None:
|
||||
assert isinstance(image_data, list)
|
||||
|
||||
num_images = len(image_data)
|
||||
if num_matched_images != num_images:
|
||||
logger.warning(
|
||||
"Number of matched image placeholders %s doesn't match "
|
||||
"the number of expected images %s; check your placeholder "
|
||||
"formatting.", num_matched_images, num_images)
|
||||
|
||||
return super()._call_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data=mm_data,
|
||||
mm_kwargs=mm_kwargs,
|
||||
)
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.batched("image"),
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
|
||||
def _get_prompt_replacements(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
hf_config = self.info.get_hf_config()
|
||||
if not hasattr(hf_config, "visual"):
|
||||
return []
|
||||
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
special_tokens: dict[str,
|
||||
int] = tokenizer.special_tokens # type: ignore
|
||||
|
||||
processor = self.info.get_hf_processor()
|
||||
img_start_id = special_tokens[processor.image_start_tag]
|
||||
img_end_id = special_tokens[processor.image_end_tag]
|
||||
img_pad_id = special_tokens[processor.image_pad_tag]
|
||||
|
||||
num_image_tokens = self.info.get_num_image_tokens()
|
||||
image_tokens = [img_pad_id] * num_image_tokens
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target=[img_start_id, img_end_id],
|
||||
replacement=PromptReplacementDetails(
|
||||
full=[img_start_id] + image_tokens + [img_end_id],
|
||||
features=image_tokens,
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
class QWenBaseModel(nn.Module, SupportsPP, SupportsLoRA):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
self.quant_config = quant_config
|
||||
self.transformer = QWenModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "transformer"))
|
||||
self.transformer = transformer_type(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "transformer"))
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
@@ -916,104 +288,6 @@ class QWenBaseModel(nn.Module, SupportsPP, SupportsLoRA):
self.make_empty_intermediate_tensors = (
|
||||
self.transformer.make_empty_intermediate_tensors)
|
||||
|
||||
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
|
||||
h = w = self.config.visual["image_size"]
|
||||
expected_dims = (3, h, w)
|
||||
actual_dims = tuple(data.shape[1:])
|
||||
|
||||
if actual_dims != expected_dims:
|
||||
expected_expr = ("batch_size", *map(str, expected_dims))
|
||||
raise ValueError(
|
||||
f"The expected shape of pixel values is {expected_expr}. "
|
||||
f"You supplied {tuple(data.shape)}.")
|
||||
|
||||
return data
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[QwenImageInputs]:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
image_embeds = kwargs.pop("image_embeds", None)
|
||||
|
||||
if pixel_values is not None:
|
||||
if not isinstance(pixel_values, torch.Tensor):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
return QwenImagePixelInputs(
|
||||
type="pixel_values",
|
||||
data=self._validate_pixel_values(
|
||||
flatten_bn(pixel_values, concat=True)),
|
||||
)
|
||||
|
||||
if image_embeds is not None:
|
||||
if not isinstance(image_embeds, torch.Tensor):
|
||||
raise ValueError("Incorrect type of image embeddings. "
|
||||
f"Got type: {type(image_embeds)}")
|
||||
|
||||
return QwenImageEmbeddingInputs(
|
||||
type="image_embeds",
|
||||
data=flatten_bn(image_embeds),
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def _process_image_input(self,
|
||||
image_input: QwenImageInputs) -> torch.Tensor:
|
||||
if image_input["type"] == "image_embeds":
|
||||
return image_input["data"]
|
||||
|
||||
assert self.transformer.visual is not None
|
||||
return self.transformer.visual(image_input["data"])
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return None
|
||||
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
return vision_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[NestedTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.transformer.get_input_embeddings(input_ids)
|
||||
|
||||
if multimodal_embeddings is not None:
|
||||
assert self.transformer.visual is not None
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids, inputs_embeds, multimodal_embeddings,
|
||||
self.transformer.visual.image_pad_id)
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if intermediate_tensors is not None:
|
||||
inputs_embeds = None
|
||||
|
||||
# NOTE: In v1, inputs_embeds is always generated at model runner, this
|
||||
# condition is for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
vision_embeddings)
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -1072,7 +346,7 @@ class QWenBaseModel(nn.Module, SupportsPP, SupportsLoRA):
|
||||
return loaded_params
|
||||
|
||||
|
||||
class QWenLLM(QWenBaseModel):
|
||||
class QWenLMHeadModel(QWenBaseModel, SupportsPP, SupportsLoRA):
|
||||
packed_modules_mapping = {
|
||||
"c_attn": ["c_attn"],
|
||||
"gate_up_proj": [
|
||||
@ -1090,76 +364,30 @@ class QWenLLM(QWenBaseModel):
|
||||
embedding_modules = {}
|
||||
embedding_padding_modules = []
|
||||
|
||||
|
||||
class QWenVL(QWenBaseModel, SupportsMultiModal):
|
||||
packed_modules_mapping = {
|
||||
"c_attn": ["c_attn"],
|
||||
"gate_up_proj": [
|
||||
"w2",
|
||||
"w1",
|
||||
],
|
||||
}
|
||||
# LoRA specific attributes
|
||||
supported_lora_modules = [
|
||||
"c_attn",
|
||||
"gate_up_proj",
|
||||
"c_proj",
|
||||
# visual module
|
||||
"out_proj",
|
||||
"in_proj",
|
||||
"c_fc",
|
||||
# resampler
|
||||
"kv_proj",
|
||||
]
|
||||
|
||||
embedding_modules = {}
|
||||
embedding_padding_modules = []
|
||||
|
||||
def get_mm_mapping(self) -> MultiModelKeys:
|
||||
"""
|
||||
Get the module prefix in multimodal models
|
||||
"""
|
||||
return MultiModelKeys.from_string_field(
|
||||
language_model="transformer.h",
|
||||
connector="transformer.visual.attn_pool",
|
||||
tower_model="transformer.visual.transformer")
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(QWenVLMultiModalProcessor,
|
||||
info=QWenVLProcessingInfo,
|
||||
dummy_inputs=QWenVLDummyInputsBuilder)
|
||||
class QWenLMHeadModel(QWenBaseModel, SupportsMultiModal, SupportsLoRA):
|
||||
"""
|
||||
QWenLMHeadModel is not only applicable to LLM but also to VL, which is not
|
||||
conducive to the current integration logic of LoRA in vLLM. Therefore, it
|
||||
is necessary to separate them.
|
||||
"""
|
||||
# Ensure that the LoRA support check passes when the class is not
|
||||
# initialized, but set all these attributes to empty.
|
||||
# These will be updated when an instance class is selected
|
||||
packed_modules_mapping = {}
|
||||
supported_lora_modules = []
|
||||
embedding_modules = {}
|
||||
embedding_padding_modules = []
|
||||
|
||||
def __new__(
|
||||
cls,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> QWenBaseModel:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
config = vllm_config.model_config.hf_config
|
||||
if hasattr(config, "visual"):
|
||||
hf_overrides = {
|
||||
"architectures": ["QwenVLForConditionalGeneration"]
|
||||
}
|
||||
raise RuntimeError(
|
||||
"The configuration of this model indicates that it supports "
|
||||
"vision inputs, but you instantiated the text-only version "
|
||||
"of this model. Please use the vision model by setting "
|
||||
f"`--hf-overrides {hf_overrides!r}`")
|
||||
|
||||
# Initialize VL
|
||||
if hasattr(config, "visual"): # noqa: SIM108
|
||||
instance_cls = QWenVL
|
||||
# Initialize LLM
|
||||
else:
|
||||
instance_cls = QWenLLM
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
# quant_config references base class members,
|
||||
# so update values before init is called
|
||||
cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping)
|
||||
cls.supported_lora_modules += instance_cls.supported_lora_modules
|
||||
cls.embedding_modules.update(instance_cls.embedding_modules)
|
||||
cls.embedding_padding_modules += instance_cls.embedding_padding_modules
|
||||
return instance_cls(vllm_config=vllm_config, prefix=prefix)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
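For context (not part of the diff): a minimal sketch of how the split surfaces to users, assuming a Qwen-VL style checkpoint whose HF config carries a `visual` section. The text-only `QWenLMHeadModel` now rejects such a config and points at the `hf_overrides` setting that selects the vision class defined in the new `qwen_vl.py` below.

from vllm import LLM

# Hypothetical invocation: with a vision-capable checkpoint, the text-only
# architecture raises the RuntimeError shown above; overriding the
# architecture name routes loading to QwenVLForConditionalGeneration instead.
llm = LLM(
    model="Qwen/Qwen-VL-Chat",
    trust_remote_code=True,
    hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]},
)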
794 vllm/model_executor/models/qwen_vl.py Normal file
@ -0,0 +1,794 @@
# SPDX-License-Identifier: Apache-2.0

# Adapted from
# https://huggingface.co/Qwen/Qwen-VL/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
"""Inference-only Qwen-VL model compatible with HuggingFace weights."""

import copy
import math
import re
import unicodedata
from functools import lru_cache, partial
from typing import (AbstractSet, Callable, Collection, List, Literal, Mapping,
                    Optional, TypedDict, Union)

import torch
from torch import nn
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from transformers import (BatchFeature, PretrainedConfig, PreTrainedTokenizer,
                          TensorType)
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
                                    NestedTensors)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptReplacement,
                                        PromptReplacementDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .qwen import QWenBaseModel, QWenModel
from .utils import flatten_bn, merge_multimodal_embeddings


class QwenImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """
    Shape: `(batch_size * num_images, 3, image_size, image_size)`

    Note that image_size is the value in the vision config to which we resize
    the image to in the normalization transform. Currently multi-image support
    can only be leveraged by passing image embeddings directly.
    """


class QwenImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, 256, hidden_size)`

    `hidden_size` must match the hidden size of the language model backbone
    and is stored in the visual config of the model if we have one.
    """


QwenImageInputs = Union[QwenImagePixelInputs, QwenImageEmbeddingInputs]

class VisualAttention(nn.Module):
    """self-attention layer class.

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        kdim: Optional[int] = None,
        vdim: Optional[int] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim \
            and self.vdim == embed_dim

        self.num_heads = num_heads

        # Per attention head and per partition values.
        assert embed_dim % num_heads == 0
        self.hidden_size_per_attention_head = embed_dim // num_heads
        self.num_attention_heads_per_partition = num_heads
        self.hidden_size_per_partition = embed_dim

        # Strided linear layer.
        assert self._qkv_same_embed_dim, \
            'Visual Attention implementation only supports self-attention'
        self.in_proj = ReplicatedLinear(embed_dim, 3 * embed_dim)
        self.out_proj = ReplicatedLinear(embed_dim, embed_dim)
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # query/key/value: [sq, b, h]
        sq, b, _ = x.size()
        mixed_x_layer, _ = self.in_proj(x)

        # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
        new_tensor_shape = mixed_x_layer.size()[:-1] + \
            (self.num_attention_heads_per_partition,
             3 * self.hidden_size_per_attention_head)
        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

        # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
        query_layer, key_layer, value_layer = mixed_x_layer.split(
            self.hidden_size_per_attention_head, dim=-1)

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(
            sq, b * self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head).transpose(0, 1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(
            sq, b * self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head).transpose(0, 1)

        q_scaled = query_layer / self.norm_factor
        if attn_mask is not None:
            attention_probs = torch.baddbmm(attn_mask, q_scaled,
                                            key_layer.transpose(-2, -1))
        else:
            attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1))
        attention_probs = attention_probs.softmax(dim=-1)

        value_layer = value_layer.view(
            sq, b * self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head).transpose(0, 1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer)

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(
            b, self.num_attention_heads_per_partition, sq,
            self.hidden_size_per_attention_head)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        output, _ = self.out_proj(context_layer)

        return output

class QwenVLMLP(nn.Module):
    """MLP for the visual component of the Qwen model."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.c_fc = ColumnParallelLinear(hidden_size,
                                         intermediate_size,
                                         bias=True,
                                         quant_config=quant_config)
        self.act_fn = get_act_fn("gelu")
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            quant_config=quant_config,
        )

    def forward(self, x):
        x, _ = self.c_fc(x)
        x = self.act_fn(x)
        x, _ = self.c_proj(x)
        return x

class VisualAttentionBlock(nn.Module):

    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_ratio: float = 4.0,
        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        self.ln_1 = norm_layer(d_model)
        self.ln_2 = norm_layer(d_model)
        mlp_width = int(d_model * mlp_ratio)
        self.attn = VisualAttention(d_model, n_head)
        self.mlp = QwenVLMLP(
            hidden_size=d_model,
            intermediate_size=mlp_width,
            quant_config=quant_config,
        )

    def attention(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None
        return self.attn(x, attn_mask=attn_mask)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        x = x + self.attention(self.ln_1(x), attn_mask=attn_mask)
        x = x + self.mlp(self.ln_2(x))
        return x

class TransformerBlock(nn.Module):

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float = 4.0,
        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.width = width
        self.layers = layers

        self.resblocks = nn.ModuleList([
            VisualAttentionBlock(width,
                                 heads,
                                 mlp_ratio,
                                 norm_layer=norm_layer,
                                 quant_config=quant_config)
            for _ in range(layers)
        ])

    def get_cast_dtype(self) -> torch.dtype:
        return self.resblocks[0].mlp.c_fc.weight.dtype

    def get_cast_device(self) -> torch.device:
        return self.resblocks[0].mlp.c_fc.weight.device

    def forward(self,
                x: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        for r in self.resblocks:
            x = r(x, attn_mask=attn_mask)
        return x

class VisionTransformer(nn.Module):

    def __init__(self,
                 image_size: int,
                 patch_size: int,
                 width: int,
                 layers: int,
                 heads: int,
                 mlp_ratio: float,
                 n_queries: int = 256,
                 output_dim: int = 512,
                 image_start_id: int = 151857,
                 quant_config: Optional[QuantizationConfig] = None,
                 **kwargs):
        super().__init__()
        image_height, image_width = self.image_size = (image_size, image_size)
        patch_height, patch_width = self.patch_size = (patch_size, patch_size)
        self.grid_size = (image_height // patch_height,
                          image_width // patch_width)
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=3,
                               out_channels=width,
                               kernel_size=patch_size,
                               stride=patch_size,
                               bias=False)

        # class embeddings and positional embeddings
        scale = width**-0.5
        self.positional_embedding = nn.Parameter(scale *
                                                 torch.randn(256, width))

        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.ln_pre = norm_layer(width)
        self.transformer = TransformerBlock(width,
                                            layers,
                                            heads,
                                            mlp_ratio,
                                            norm_layer=norm_layer,
                                            quant_config=quant_config)

        self.attn_pool = Resampler2(
            grid_size=int(math.sqrt(n_queries)),
            embed_dim=output_dim,
            num_heads=output_dim // 128,
            kv_dim=width,
            norm_layer=norm_layer,
            adaptive=False,
            do_post_projection=False,
        ).to(
            device=self.positional_embedding.device,
            dtype=self.positional_embedding.dtype,
        )

        self.ln_post = norm_layer(output_dim)
        self.proj = nn.Parameter(
            (output_dim**-0.5) * torch.randn(output_dim, output_dim))

        self.image_start_id = image_start_id
        self.image_end_id = image_start_id + 1
        self.image_pad_id = image_start_id + 2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.to(
            dtype=self.transformer.get_cast_dtype(),
            device=self.transformer.get_cast_device(),
        )

        # to patches
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1],
                      -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]

        x = x + get_abs_pos(self.positional_embedding, int(math.sqrt(
            x.size(1))))

        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.attn_pool(x)
        x = self.ln_post(x)
        x = x @ self.proj

        return x

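As a rough guide (not part of the diff), the shape bookkeeping in `VisionTransformer.forward` under the publicly documented Qwen-VL settings; the concrete numbers come from the checkpoint's `visual` config, so treat them as assumptions:

# Assumed config values: image_size=448, patch_size=14, n_queries=256
# (taken from the public Qwen-VL checkpoint, not from this diff).
image_size, patch_size, n_queries = 448, 14, 256

grid = image_size // patch_size     # conv1 yields a 32 x 32 patch grid
num_patch_tokens = grid * grid      # 1024 tokens enter the transformer
num_pooled_tokens = n_queries       # Resampler2 pools them down to 256

assert num_patch_tokens == 1024
assert num_pooled_tokens == (grid // 2) ** 2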
class QwenVLModel(QWenModel):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.visual = VisionTransformer(**config.visual,
                                        quant_config=quant_config)

@lru_cache(maxsize=1)
def _get_tokenizer_without_image_pad(
        tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer:
    """
    The logic of adding image pad tokens should only be applied in
    :class:`QwenVLProcessor`, so they are patched out here.

    The definition of the wrapped tokenizer can be found here:
    https://huggingface.co/Qwen/Qwen-VL/blob/main/tokenization_qwen.py
    """
    new_tokenizer = copy.deepcopy(tokenizer)

    class TokenizerWithoutImagePad(tokenizer.__class__):  # type: ignore

        def tokenize(
            self,
            text: str,
            allowed_special: Union[AbstractSet[str], str] = "all",
            disallowed_special: Union[Collection[str], str] = (),
            **kwargs,
        ) -> list[Union[bytes, str]]:
            text = unicodedata.normalize("NFC", text)

            return [
                self.decoder[t] for t in self.tokenizer.encode(
                    text,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            ]

        def _decode(
            self,
            token_ids: Union[int, List[int]],
            skip_special_tokens: bool = False,
            errors: Optional[str] = None,
            **kwargs,
        ) -> str:
            if isinstance(token_ids, int):
                token_ids = [token_ids]

            return self.tokenizer.decode(
                token_ids,
                errors=errors or self.errors,
            )

    TokenizerWithoutImagePad.__name__ = \
        f"{tokenizer.__class__.__name__}WithoutImagePad"

    new_tokenizer.__class__ = TokenizerWithoutImagePad
    return new_tokenizer

class QwenVLProcessor:
    """
    This model doesn't define its own HF processor,
    so we implement our own one here.

    We call the wrapped tokenizer to automatically insert image pad tokens:
    https://huggingface.co/Qwen/Qwen-VL/blob/main/tokenization_qwen.py#L245

    The image processor is defined here:
    https://huggingface.co/Qwen/Qwen-VL/blob/main/visual.py#L354
    """

    def __init__(
        self,
        config: PretrainedConfig,
        tokenizer: PreTrainedTokenizer,
    ) -> None:
        super().__init__()

        self.config = config
        self.tokenizer = tokenizer

        vision_config = config.visual
        image_size = vision_config["image_size"]

        self.image_transform = transforms.Compose([
            transforms.Resize(
                (image_size, image_size),
                interpolation=InterpolationMode.BICUBIC,
            ),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=(0.48145466, 0.4578275, 0.40821073),
                std=(0.26862954, 0.26130258, 0.27577711),
            ),
        ])

    @property
    def image_start_tag(self) -> str:
        return self.tokenizer.image_start_tag  # type: ignore

    @property
    def image_end_tag(self) -> str:
        return self.tokenizer.image_end_tag  # type: ignore

    @property
    def image_pad_tag(self) -> str:
        return self.tokenizer.image_pad_tag  # type: ignore

    def __call__(
        self,
        text: Optional[Union[TextInput, list[TextInput]]] = None,
        images: Optional[Union[ImageInput, list[ImageInput]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> BatchFeature:
        if text is None:
            text = []
        if not isinstance(text, list):
            text = [text]
        if images is None:
            images = []
        if not isinstance(images, list):
            images = [images]

        text_inputs = self.tokenizer(text)

        if len(images) == 0:
            image_inputs = {}
        else:
            pixel_values = [self.image_transform(image) for image in images]
            image_inputs = {"pixel_values": torch.stack(pixel_values)}

        return BatchFeature(
            {
                **text_inputs,
                **image_inputs,
            },
            tensor_type=return_tensors,
        )

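A hypothetical usage sketch of this processor (again, not exercised by the diff); it assumes the tokenizer and config of a public Qwen-VL checkpoint loaded with `trust_remote_code`:

from PIL import Image
from transformers import AutoConfig, AutoTokenizer

config = AutoConfig.from_pretrained("Qwen/Qwen-VL", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL",
                                          trust_remote_code=True)
processor = QwenVLProcessor(config, tokenizer)

out = processor(
    text="Picture 1: <img></img>\nWhat is shown in the picture?",
    images=Image.new("RGB", (640, 480)),
    return_tensors="pt",
)
# Expected contents (assumption): tokenized text under "input_ids", with image
# pads inserted by the wrapped tokenizer, and a single resized, normalized
# image under "pixel_values" with shape (1, 3, image_size, image_size).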
class QwenVLProcessingInfo(BaseProcessingInfo):

    def get_tokenizer(self) -> PreTrainedTokenizer:
        tokenizer = self.ctx.tokenizer
        assert isinstance(tokenizer, PreTrainedTokenizer)

        return _get_tokenizer_without_image_pad(tokenizer)

    def get_hf_processor(self) -> QwenVLProcessor:
        tokenizer = self.ctx.tokenizer
        assert isinstance(tokenizer, PreTrainedTokenizer)

        return QwenVLProcessor(self.get_hf_config(), tokenizer)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        return {"image": self.get_num_image_tokens()}

    def get_num_image_tokens(self) -> int:
        hf_config = self.get_hf_config()
        vision_config = hf_config.visual

        image_size = vision_config["image_size"]
        patch_size = vision_config["patch_size"]
        grid_length = image_size // patch_size // 2
        return grid_length * grid_length

class QwenVLDummyInputsBuilder(BaseDummyInputsBuilder[QwenVLProcessingInfo]):

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        hf_config = self.info.get_hf_config()
        vision_config = hf_config.visual

        processor = self.info.get_hf_processor()
        img_start = processor.image_start_tag
        img_end = processor.image_end_tag

        target_width = target_height = vision_config["image_size"]
        num_images = mm_counts.get("image", 0)

        mm_data = {
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images)
        }

        return ProcessorInputs(
            prompt_text="".join(f"Picture {i}: {img_start}{img_end}\n"
                                for i in range(1, num_images + 1)),
            mm_data=mm_data,
        )

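For reference, the dummy prompt this builder produces follows the "Picture N:" convention used by Qwen-VL; e.g. for two images (the tags shown are the ones the Qwen-VL tokenizer is assumed to expose):

img_start, img_end = "<img>", "</img>"
num_images = 2

prompt_text = "".join(f"Picture {i}: {img_start}{img_end}\n"
                      for i in range(1, num_images + 1))
assert prompt_text == "Picture 1: <img></img>\nPicture 2: <img></img>\n"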
class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        # Drops anything between <img>/</img> tags; encoding with the tokenizer
        # will automatically add the image pads for the context.
        prompt, num_matched_images = re.subn(
            r"(Picture \d*: <img>).*?(<\/img>\n)",
            r"\1\2",
            prompt,
        )

        image_data = mm_data.get("images")
        if image_data is not None:
            assert isinstance(image_data, list)

            num_images = len(image_data)
            assert num_matched_images == num_images

        return super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        tokenizer = self.info.get_tokenizer()
        special_tokens: dict[str,
                             int] = tokenizer.special_tokens  # type: ignore

        processor = self.info.get_hf_processor()
        img_start_id = special_tokens[processor.image_start_tag]
        img_end_id = special_tokens[processor.image_end_tag]
        img_pad_id = special_tokens[processor.image_pad_tag]

        num_image_tokens = self.info.get_num_image_tokens()
        image_tokens = [img_pad_id] * num_image_tokens

        return [
            PromptReplacement(
                modality="image",
                target=[img_start_id, img_end_id],
                replacement=PromptReplacementDetails(
                    full=[img_start_id] + image_tokens + [img_end_id],
                    features=image_tokens,
                ),
            )
        ]

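In effect, the replacement above rewrites the two-token `<img></img>` span into start + pads + end; a small illustration using the default ids from `VisionTransformer` above (151857/151858/151859) and the 256-token budget that `get_num_image_tokens` yields under the assumed config:

img_start_id, img_end_id, img_pad_id = 151857, 151858, 151859
num_image_tokens = 256

target = [img_start_id, img_end_id]
full = [img_start_id] + [img_pad_id] * num_image_tokens + [img_end_id]

assert len(full) == num_image_tokens + 2
# `features` (just the pad tokens) mark where the vision embeddings are
# merged back in by merge_multimodal_embeddings.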
@MULTIMODAL_REGISTRY.register_processor(QwenVLMultiModalProcessor,
                                        info=QwenVLProcessingInfo,
                                        dummy_inputs=QwenVLDummyInputsBuilder)
class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
                                     SupportsMultiModal):
    packed_modules_mapping = {
        "c_attn": ["c_attn"],
        "gate_up_proj": [
            "w2",
            "w1",
        ],
    }
    # LoRA specific attributes
    supported_lora_modules = [
        "c_attn",
        "gate_up_proj",
        "c_proj",
        # visual module
        "out_proj",
        "in_proj",
        "c_fc",
        # resampler
        "kv_proj",
    ]

    embedding_modules = {}
    embedding_padding_modules = []

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="transformer.h",
            connector="transformer.visual.attn_pool",
            tower_model="transformer.visual.transformer")

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        transformer_type: type[QwenVLModel] = QwenVLModel,
    ) -> None:
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            transformer_type=transformer_type,
        )

        self.transformer: QwenVLModel

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        h = w = self.config.visual["image_size"]
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[QwenImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is not None:
            if not isinstance(pixel_values, torch.Tensor):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            return QwenImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(
                    flatten_bn(pixel_values, concat=True)),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return QwenImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        return None

    def _process_image_input(self,
                             image_input: QwenImageInputs) -> torch.Tensor:
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        return self.transformer.visual(image_input["data"])

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None

        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        inputs_embeds = self.transformer.get_input_embeddings(input_ids)

        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.transformer.visual.image_pad_id)

        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors,
                                         inputs_embeds)
        return hidden_states
@ -39,7 +39,7 @@ _TEXT_GENERATION_MODELS = {
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
    "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    # ChatGLMModel supports multimodal
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
@ -90,7 +90,7 @@ _TEXT_GENERATION_MODELS = {
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    # QWenLMHeadModel supports multimodal
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
@ -156,10 +156,9 @@ _MULTIMODAL_MODELS = {
    "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"),
@ -175,7 +174,7 @@ _MULTIMODAL_MODELS = {
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501