vllm-project/vllm
parent 775f00f81e
commit d394787e52
@@ -247,6 +247,11 @@ Multimodal Language Models
     - Image\ :sup:`E+`
     - :code:`microsoft/Phi-3-vision-128k-instruct`, :code:`microsoft/Phi-3.5-vision-instruct` etc.
     -
+  * - :code:`PixtralForConditionalGeneration`
+    - Pixtral
+    - Image\ :sup:`+`
+    - :code:`mistralai/Pixtral-12B-2409`
+    -
   * - :code:`QWenLMHeadModel`
     - Qwen-VL
     - Image\ :sup:`E`
examples/offline_inference_pixtral.py (new file, 164 lines)
@@ -0,0 +1,164 @@
# ruff: noqa
import argparse

from vllm import LLM
from vllm.sampling_params import SamplingParams

# This script is an offline demo for running Pixtral.
#
# If you want to run a server/client setup, please follow this code:
#
# - Server:
#
# ```bash
# vllm serve mistralai/Pixtral-12B-2409 --tokenizer_mode mistral --limit_mm_per_prompt 'image=4' --max_num_batched_tokens 16384
# ```
#
# - Client:
#
# ```bash
# curl --location 'http://<your-node-url>:8000/v1/chat/completions' \
# --header 'Content-Type: application/json' \
# --header 'Authorization: Bearer token' \
# --data '{
#     "model": "mistralai/Pixtral-12B-2409",
#     "messages": [
#       {
#         "role": "user",
#         "content": [
#             {"type": "text", "text": "Describe this image in detail please."},
#             {"type": "image_url", "image_url": {"url": "https://s3.amazonaws.com/cms.ipressroom.com/338/files/201808/5b894ee1a138352221103195_A680%7Ejogging-edit/A680%7Ejogging-edit_hero.jpg"}},
#             {"type": "text", "text": "and this one as well. Answer in French."},
#             {"type": "image_url", "image_url": {"url": "https://www.wolframcloud.com/obj/resourcesystem/images/a0e/a0ee3983-46c6-4c92-b85d-059044639928/6af8cfb971db031b.png"}}
#         ]
#       }
#     ]
#   }'
# ```
#
# Usage:
#     python offline_inference_pixtral.py simple
#     python offline_inference_pixtral.py advanced


def run_simple_demo():
    model_name = "mistralai/Pixtral-12B-2409"
    sampling_params = SamplingParams(max_tokens=8192)

    llm = LLM(model=model_name, tokenizer_mode="mistral")

    prompt = "Describe this image in one sentence."
    image_url = "https://picsum.photos/id/237/200/300"

    messages = [
        {
            "role":
            "user",
            "content": [
                {
                    "type": "text",
                    "text": prompt
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": image_url
                    }
                },
            ],
        },
    ]
    outputs = llm.chat(messages, sampling_params=sampling_params)

    print(outputs[0].outputs[0].text)


def run_advanced_demo():
    model_name = "mistralai/Pixtral-12B-2409"
    max_img_per_msg = 5
    max_tokens_per_img = 4096

    sampling_params = SamplingParams(max_tokens=8192, temperature=0.7)
    llm = LLM(
        model=model_name,
        tokenizer_mode="mistral",
        limit_mm_per_prompt={"image": max_img_per_msg},
        max_num_batched_tokens=max_img_per_msg * max_tokens_per_img,
    )

    prompt = "Describe the following image."

    url_1 = "https://huggingface.co/datasets/patrickvonplaten/random_img/resolve/main/yosemite.png"
    url_2 = "https://picsum.photos/seed/picsum/200/300"
    url_3 = "https://picsum.photos/id/32/512/512"

    messages = [
        {
            "role":
            "user",
            "content": [
                {
                    "type": "text",
                    "text": prompt
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": url_1
                    }
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": url_2
                    }
                },
            ],
        },
        {
            "role": "assistant",
            "content": "The images show nature.",
        },
        {
            "role": "user",
            "content": "More details please and answer only in French!",
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": url_3
                    }
                },
            ],
        },
    ]

    outputs = llm.chat(messages=messages, sampling_params=sampling_params)
    print(outputs[0].outputs[0].text)


def main():
    parser = argparse.ArgumentParser(
        description="Run a demo in simple or advanced mode.")

    parser.add_argument(
        "mode",
        choices=["simple", "advanced"],
        help="Specify the demo mode: 'simple' or 'advanced'",
    )

    args = parser.parse_args()

    if args.mode == "simple":
        print("Running simple demo...")
        run_simple_demo()
    elif args.mode == "advanced":
        print("Running advanced demo...")
        run_advanced_demo()


if __name__ == "__main__":
    main()
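The comment block at the top of this example shows a curl-based client for the server setup. For reference, here is a rough, untested sketch of the same request using the `openai` Python client; it is not part of this commit and assumes the `vllm serve mistralai/Pixtral-12B-2409 --tokenizer_mode mistral ...` command from the comments is already serving on localhost:8000.

# Hypothetical companion snippet (not part of this commit): query the
# OpenAI-compatible server started with the `vllm serve` command above.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="token")

response = client.chat.completions.create(
    model="mistralai/Pixtral-12B-2409",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image in one sentence."},
            {"type": "image_url",
             "image_url": {"url": "https://picsum.photos/id/237/200/300"}},
        ],
    }],
    max_tokens=256,
)
print(response.choices[0].message.content)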
@@ -25,7 +25,7 @@ pyzmq
 msgspec
 gguf == 0.9.1
 importlib_metadata
-mistral_common >= 1.3.4
+mistral_common >= 1.4.0
 pyyaml
 six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
 einops # Required for Qwen2-VL.
tests/models/test_pixtral.py (new file, 64 lines)
@@ -0,0 +1,64 @@
"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling.

Run `pytest tests/models/test_pixtral.py`.
"""
import pytest

from vllm.sampling_params import SamplingParams

pytestmark = pytest.mark.vlm

MODELS = ["mistralai/Pixtral-12B-2409"]


@pytest.mark.skip(
    reason=
    "Model is too big, test passed on A100 locally but will OOM on CI machine."
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    image_urls = [
        "https://picsum.photos/id/237/200/300",
        "https://picsum.photos/seed/picsum/200/300"
    ]
    expected = [
        "The image depicts a black dog lying on a wooden surface, looking directly at the camera with a calm expression.",  # noqa
        "The image depicts a serene landscape with a snow-covered mountain under a pastel-colored sky during sunset."  # noqa
    ]
    prompt = "Describe the image in one short sentence."

    sampling_params = SamplingParams(max_tokens=512, temperature=0.0)

    with vllm_runner(model, dtype=dtype,
                     tokenizer_mode="mistral") as vllm_model:

        for i, image_url in enumerate(image_urls):
            messages = [
                {
                    "role":
                    "user",
                    "content": [{
                        "type": "text",
                        "text": prompt
                    }, {
                        "type": "image_url",
                        "image_url": {
                            "url": image_url
                        }
                    }]
                },
            ]

            outputs = vllm_model.model.chat(messages,
                                            sampling_params=sampling_params)
            assert outputs[0].outputs[0].text == expected[i]
@@ -148,7 +148,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
             return f"<|image_{current_count}|>"
         if model_type == "minicpmv":
             return "(<image>./</image>)"
-        if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"):
+        if model_type in ("blip-2", "chatglm", "fuyu", "paligemma",
+                          "pixtral"):
             # These models do not use image tokens in the prompt
             return None
         if model_type == "qwen":
@@ -92,6 +92,8 @@ _MULTIMODAL_MODELS = {
     "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
     "UltravoxModel": ("ultravox", "UltravoxModel"),
     "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
+    "PixtralForConditionalGeneration": ("pixtral",
+                                        "PixtralForConditionalGeneration"),
     "Qwen2VLForConditionalGeneration": ("qwen2_vl",
                                         "Qwen2VLForConditionalGeneration"),
 }
vllm/model_executor/models/pixtral.py (new file, 551 lines)
@@ -0,0 +1,551 @@
import math
from array import array
from dataclasses import dataclass, fields
from itertools import tee
from typing import Iterable, List, Mapping, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mistral_common.protocol.instruct.messages import ImageChunk
from PIL import Image
from transformers import PretrainedConfig
from xformers.ops.fmha import memory_efficient_attention
from xformers.ops.fmha.attn_bias import BlockDiagonalMask

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
                           SequenceData)

from .interfaces import SupportsMultiModal
from .utils import init_vllm_registered_model


def get_max_pixtral_image_tokens(ctx: InputContext):
    tokenizer = cached_get_tokenizer(
        ctx.model_config.tokenizer,
        tokenizer_mode=ctx.model_config.tokenizer_mode)
    mm_encoder = tokenizer.instruct.mm_encoder

    max_image_size = mm_encoder.mm_config.max_image_size
    image_patch_size = mm_encoder.mm_config.image_patch_size

    return ((max_image_size // image_patch_size)**2)


def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
                           mm_counts: Mapping[str, int]):
    tokenizer = cached_get_tokenizer(
        ctx.model_config.tokenizer,
        tokenizer_mode=ctx.model_config.tokenizer_mode)
    mm_encoder = tokenizer.instruct.mm_encoder

    mm_config = ctx.model_config.multimodal_config
    max_num_images_per_request = mm_config.limit_per_prompt.get("image", 1)

    # approximate image size
    size = int(math.sqrt(seq_len) * mm_encoder.mm_config.image_patch_size)

    image = Image.new("RGB", (size, size), color=0)
    img_chunk = ImageChunk(image=image)

    tokens = mm_encoder(img_chunk).tokens
    token_ids = max_num_images_per_request * array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                                   tokens)

    seq_data = SequenceData(token_ids)
    mm_data = {"image": max_num_images_per_request * [image]}
    return seq_data, mm_data


def input_mapper_for_pixtral(ctx: InputContext,
                             data: object) -> MultiModalInputs:
    """Maps the input data to its MultiModalInputs (if any).

    Args:
        ctx: Context of the loaded model.
        data: data potentially containing image/image embeddings to be mapped
            to image tensors consumed in .forward() by the Pixtral model.

    Returns:
        MultiModalInputs containing the processed image tensors or
        image embeddings.
    """
    # Reuse the Mistral tokenizer's multimodal encoder to preprocess the images
    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode)

    data_list = data if isinstance(data, list) else [data]

    images = []
    for image_data in data_list:
        image = ImageChunk(image=image_data)
        encoding = tokenizer.instruct.mm_encoder(image)
        image = torch.from_numpy(encoding.image).to(device="cuda",
                                                    dtype=torch.float16)
        images.append(image)

    return MultiModalInputs({"images": images})


def merge_multimodal_embeddings(input_ids: torch.Tensor,
                                inputs_embeds: torch.Tensor,
                                image_features: Optional[List[torch.Tensor]],
                                image_id: int) -> torch.Tensor:
    text_locations = input_ids != image_id
    image_locations = input_ids == image_id

    seq_len = input_ids.shape[0]

    N_txt = text_locations.sum().item()
    _, D_txt = inputs_embeds.shape
    N_img, D_img = image_features.shape

    assert (D_txt == D_img), (f"Text features dim {D_txt} should be equal "
                              f"to image features dim {D_img}")
    assert (seq_len == N_txt +
            N_img), (f"seq_len {seq_len} should be equal to N_txt + N_img "
                     f"{(N_txt, N_img, image_locations.sum().item())}")

    inputs_embeds[image_locations, :] = image_features
    return inputs_embeds


@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral)
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal):

    def __init__(self,
                 config: PretrainedConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        dataclass_fields = {field.name for field in fields(VisionEncoderArgs)}
        vision_args = {
            key: value
            for key, value in self.config.vision_config.to_dict().items()
            if key in dataclass_fields
        }

        self.vision_args = VisionEncoderArgs(**vision_args)

        # init MistralForCausalLM
        self.language_model = init_vllm_registered_model(
            config.text_config, cache_config, quant_config)

        self.vision_encoder = VisionTransformer(self.vision_args)
        self.vision_language_adapter = VisionLanguageAdapter(
            self.vision_args, dim=config.text_config.hidden_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        """Run forward pass for pixtral.

        Image tokens in `input_ids` are replaced by the projected vision
        encoder features before running the language model.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)

            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.vision_args.image_token_id)

            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  None,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def _parse_and_validate_image_input(
        self,
        images: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor],
                               torch.Tensor]] = None
    ) -> Optional[List[torch.Tensor]]:
        if images is None:
            return None

        if isinstance(images, torch.Tensor):
            # always take last images
            images = [images[-1][i] for i in range(images.size(1))]
        elif isinstance(images, list):
            # always take last images
            images = [images[-1][i] for i in range(len(images[0]))]

        return images

    def _process_image_input(self,
                             image_input: List[torch.Tensor]) -> torch.Tensor:
        return self.vision_language_adapter(self.vision_encoder(image_input))

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

        def is_vision_encoder_weights(weight: Tuple[str, torch.Tensor]):
            return weight[0].startswith("vision_encoder")

        def is_vision_lang_adapter_weights(weight: Tuple[str, torch.Tensor]):
            return weight[0].startswith("vision_language_adapter")

        def is_vision_weights(weight: Tuple[str, torch.Tensor]):
            return is_vision_encoder_weights(
                weight) or is_vision_lang_adapter_weights(weight)

        llm_weights, vision_encoder_weights, vision_lang_adapter_weights = tee(
            weights, 3)

        # llm
        llm_weights = filter(lambda x: not is_vision_weights(x), llm_weights)
        self.language_model.load_weights(llm_weights)

        # vision encoder
        vision_encoder_weights = filter(is_vision_encoder_weights,
                                        vision_encoder_weights)
        vision_encoder_dict = dict(self.vision_encoder.named_parameters())
        for name, loaded_weight in vision_encoder_weights:
            # cut 'vision_encoder.'
            name = '.'.join(name.split(".")[1:])
            param = vision_encoder_dict[name]

            default_weight_loader(param, loaded_weight)

        # adapter
        vision_lang_adapter_weights = filter(is_vision_lang_adapter_weights,
                                             vision_lang_adapter_weights)
        vision_lang_adapter_dict = dict(
            self.vision_language_adapter.named_parameters())
        for name, loaded_weight in vision_lang_adapter_weights:
            # cut 'vision_language_adapter.'
            name = '.'.join(name.split(".")[1:])
            param = vision_lang_adapter_dict[name]
            default_weight_loader(param, loaded_weight)


# Vision encoder
@dataclass
class VisionEncoderArgs:
    hidden_size: int
    num_channels: int
    image_size: int
    patch_size: int
    intermediate_size: int
    num_hidden_layers: int
    num_attention_heads: int
    rope_theta: float  # for rope-2D
    image_token_id: int


def _reshape_for_broadcast(freqs_cis: torch.Tensor,
                           x: torch.Tensor) -> torch.Tensor:
    """
    freqs_cis: complex - (seq_len, head_dim / 2)
    x: complex - (bsz, seq_len, head_dim / 2)
    """
    ndim = x.ndim
    assert ndim > 1
    assert freqs_cis.shape == (x.shape[1], x.shape[-1]), (
        freqs_cis.shape,
        (x.shape[1], x.shape[-1]),
    )
    shape = [
        d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)
    ]
    return freqs_cis.view(*shape)


def precompute_freqs_cis_2d(
    dim: int,
    height: int,
    width: int,
    theta: float,
) -> torch.Tensor:
    """
    freqs_cis: 2D complex tensor of shape (height, width, dim // 2)
    to be indexed by (height, width) position tuples
    """
    # (dim / 2) frequency bases
    freqs = 1.0 / (theta**(torch.arange(0, dim, 2).float() / dim))

    h = torch.arange(height, device=freqs.device)
    w = torch.arange(width, device=freqs.device)

    freqs_h = torch.outer(h, freqs[::2]).float()
    freqs_w = torch.outer(w, freqs[1::2]).float()
    freqs_2d = torch.cat(
        [
            freqs_h[:, None, :].repeat(1, width, 1),
            freqs_w[None, :, :].repeat(height, 1, 1),
        ],
        dim=-1,
    )
    return torch.polar(torch.ones_like(freqs_2d), freqs_2d)


def apply_rotary_emb_vit(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    assert freqs_cis.dtype == torch.complex64
    freqs_cis = _reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


class FeedForward(nn.Module):

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        assert args.intermediate_size is not None
        self.w1 = nn.Linear(args.hidden_size,
                            args.intermediate_size,
                            bias=False)
        self.w2 = nn.Linear(args.intermediate_size,
                            args.hidden_size,
                            bias=False)
        self.w3 = nn.Linear(args.hidden_size,
                            args.intermediate_size,
                            bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class Attention(nn.Module):

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.args = args
        assert not args.hidden_size % args.num_attention_heads
        self.n_heads = args.num_attention_heads
        self.head_dim = args.hidden_size // args.num_attention_heads

        self.wq = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
        self.wk = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
        self.wv = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
        self.wo = nn.Linear(args.hidden_size, args.hidden_size, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        mask: BlockDiagonalMask,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        batch, patches, _ = x.shape

        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        q = q.reshape(batch, patches, self.n_heads, self.head_dim)
        k = k.reshape(batch, patches, self.n_heads, self.head_dim)
        v = v.reshape(batch, patches, self.n_heads, self.head_dim)

        q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis)
        out = memory_efficient_attention(q, k, v, attn_bias=mask)
        out = out.reshape(batch, patches, self.n_heads * self.head_dim)
        return self.wo(out)


class TransformerBlock(nn.Module):

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.attention = Attention(args)
        self.feed_forward = FeedForward(args)
        self.attention_norm = RMSNorm(args.hidden_size, eps=1e-5)
        self.ffn_norm = RMSNorm(args.hidden_size, eps=1e-5)

    def forward(
        self,
        x: torch.Tensor,
        mask: BlockDiagonalMask,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        r = self.attention.forward(self.attention_norm(x),
                                   mask=mask,
                                   freqs_cis=freqs_cis)
        h = x + r
        r = self.feed_forward.forward(self.ffn_norm(h))
        out = h + r
        return out


class Transformer(nn.Module):

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.layers = torch.nn.ModuleList()
        for _ in range(args.num_hidden_layers):
            self.layers.append(TransformerBlock(args))

    def forward(
        self,
        x: torch.Tensor,
        mask: BlockDiagonalMask,
        freqs_cis: Optional[torch.Tensor],
    ) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x, mask=mask, freqs_cis=freqs_cis)
        return x


def position_meshgrid(patch_embeds_list: list[torch.Tensor], ) -> torch.Tensor:
    positions = torch.cat([
        torch.stack(
            torch.meshgrid(
                torch.arange(p.shape[-2]),
                torch.arange(p.shape[-1]),
                indexing="ij",
            ),
            dim=-1,
        ).reshape(-1, 2) for p in patch_embeds_list
    ])
    return positions


class VisionTransformer(nn.Module):

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.args = args
        self.patch_conv = nn.Conv2d(
            in_channels=args.num_channels,
            out_channels=args.hidden_size,
            kernel_size=args.patch_size,
            stride=args.patch_size,
            bias=False,
        )
        self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5)
        self.transformer = Transformer(args)

        head_dim = self.args.hidden_size // self.args.num_attention_heads
        assert head_dim % 2 == 0, "ROPE requires even head_dim"
        self._freqs_cis: Optional[torch.Tensor] = None

    @property
    def max_patches_per_side(self) -> int:
        return self.args.image_size // self.args.patch_size

    @property
    def device(self) -> torch.device:
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype

    @property
    def freqs_cis(self) -> torch.Tensor:
        if self._freqs_cis is None:
            self._freqs_cis = precompute_freqs_cis_2d(
                dim=self.args.hidden_size // self.args.num_attention_heads,
                height=self.max_patches_per_side,
                width=self.max_patches_per_side,
                theta=self.args.rope_theta,
            )

        if self._freqs_cis.device != self.device:
            self._freqs_cis = self._freqs_cis.to(device=self.device)

        return self._freqs_cis

    def forward(
        self,
        images: List[torch.Tensor],
    ) -> torch.Tensor:
        """
        Args:
            images: list of N_img images of variable sizes,
                each of shape (C, H, W)
        Returns:
            image_features: tensor of token features for
                all tokens of all images of shape (N_toks, D)
        """
        # pass images through initial convolution independently
        patch_embeds_list = [
            self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images
        ]

        # flatten to a single sequence
        patch_embeds = torch.cat(
            [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
        patch_embeds = self.ln_pre(patch_embeds)

        # positional embeddings
        positions = position_meshgrid(patch_embeds_list).to(self.device)
        freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]]

        # pass through Transformer with a block diagonal mask delimiting images
        mask = BlockDiagonalMask.from_seqlens(
            [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], )
        out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)

        # remove batch dimension of the single sequence
        return out.squeeze(0)


class VisionLanguageAdapter(nn.Module):

    def __init__(self, args: VisionEncoderArgs, dim: int):
        super().__init__()
        assert isinstance(args, VisionEncoderArgs)
        self.w_in = nn.Linear(
            args.hidden_size,
            dim,
            bias=True,
        )
        self.gelu = nn.GELU()
        self.w_out = nn.Linear(dim, dim, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w_out(self.gelu(self.w_in(x)))
@@ -70,7 +70,7 @@ def file_or_path_exists(model: Union[str, Path], config_name, revision,
     if Path(model).exists():
         return (Path(model) / config_name).is_file()
 
-    return file_exists(model, HF_CONFIG_NAME, revision=revision, token=token)
+    return file_exists(model, config_name, revision=revision, token=token)
 
 
 def get_config(
@@ -205,14 +205,25 @@ def load_params_config(model, revision) -> PretrainedConfig:
     config_dict["hidden_act"] = config_dict.get("activation", "silu")
     config_dict["tie_word_embeddings"] = config_dict.get(
         "tie_embeddings", False)
+    config_dict["max_seq_len"] = config_dict.get("max_seq_len", 128_000)
 
-    if config_dict["model_type"] == "transformer":
-        if "moe" in config_dict:
-            config_dict["architectures"] = ["MixtralForCausalLM"]
-        else:
-            config_dict["architectures"] = ["MistralForCausalLM"]
+    if config_dict.get("moe") is not None:
+        config_dict["architectures"] = ["MixtralForCausalLM"]
+    else:
+        config_dict["architectures"] = ["MistralForCausalLM"]
 
-    return recurse_elems(config_dict)
+    if config_dict.get("vision_encoder") is not None:
+        multimodal_config = config_dict.pop("vision_encoder")
+
+        config_dict = {
+            "text_config": config_dict,
+            "vision_config": multimodal_config
+        }
+        config_dict["architectures"] = ["PixtralForConditionalGeneration"]
+        config_dict["model_type"] = "pixtral"
+
+    config = recurse_elems(config_dict)
+    return config
 
 
 def get_hf_image_processor_config(
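As a purely illustrative sketch (not part of the commit; the field names and values below are made up), this is the reshaping that the new vision_encoder branch of load_params_config applies to a Mistral-format params.json:

# Illustrative only: toy params.json-style dict with hypothetical values.
params = {
    "dim": 5120,
    "n_layers": 40,
    "vision_encoder": {"hidden_size": 1024, "image_token_id": 10},
}

vision_config = params.pop("vision_encoder")
config_dict = {"text_config": params, "vision_config": vision_config}
config_dict["architectures"] = ["PixtralForConditionalGeneration"]
config_dict["model_type"] = "pixtral"
# Language-model keys end up under "text_config" and the vision encoder
# keys under "vision_config", mirroring the diff above.
print(config_dict["architectures"], list(config_dict["vision_config"]))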