[Bugfix] Add phi3v resize for dynamic shape and fix torchvision requirement (#5772)
parent 5d4d90536f
commit edd5fe5fa2
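In short: the dependency changes pin torchvision next to torch in the CPU and CUDA requirements (and drop the unpinned torchvision entry from the dev requirements), while the model change registers an image input processor for phi3v that resizes PIL inputs to the static image_input_shape whenever the HD transform of the incoming size differs from it, since dynamic image shapes are not yet supported. The test matrix gains a second image size, and the regression test is marked xfail pending dynamic image token replacement.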
--- a/requirements-cpu.txt
+++ b/requirements-cpu.txt
@@ -3,4 +3,5 @@
 # Dependencies for x86_64 CPUs
 torch == 2.3.1+cpu
+torchvision == 0.18.1+cpu # required for the image processor of phi3v, this must be updated alongside torch
 triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error.
--- a/requirements-cuda.txt
+++ b/requirements-cuda.txt
@@ -5,5 +5,7 @@
 ray >= 2.9
 nvidia-ml-py # for pynvml package
 torch == 2.3.0
+# These must be updated alongside torch
+torchvision == 0.18.0 # Required for phi3v processor, also see https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
 xformers == 0.0.26.post1 # Requires PyTorch 2.3.0
 vllm-flash-attn == 2.5.9 # Requires PyTorch 2.3.0
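Since torch and torchvision releases are paired, a quick sanity check in the installed environment can confirm the pin is consistent. A minimal sketch; the exact version strings depend on which requirements file was installed:

import torch
import torchvision

# Expect a matched pair, e.g. torch 2.3.0 with torchvision 0.18.0 (CUDA),
# or torch 2.3.1+cpu with torchvision 0.18.1+cpu (CPU).
print("torch:", torch.__version__)
print("torchvision:", torchvision.__version__)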
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -14,7 +14,6 @@ peft
 requests
 ray
 sentence-transformers # required for embedding
-torchvision # required for the image processor of phi3v
 
 # Benchmarking
 aiohttp
--- a/tests/models/test_phi3v.py
+++ b/tests/models/test_phi3v.py
@@ -22,6 +22,7 @@ assert len(HF_IMAGE_PROMPTS) == len(IMAGE_FILES)
 def iter_phi3v_configs(model_name: str):
     image_hw_to_feature_size = {
         (1008, 1344): 1921,
+        (2016, 2688): 1933,
     }
 
     for (h, w), f in image_hw_to_feature_size.items():
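The new (2016, 2688) entry adds a second input-shape/feature-size pair to the configs generated by iter_phi3v_configs, so the suite covers more than the single (1008, 1344) shape.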
@@ -75,6 +76,9 @@ if is_cpu():
 # TODO: Add test for `tensor_parallel_size` [ref: PR #3883]
 # Since we use _attn_implementation="eager" for hf_runner, here is
 # numeric difference for longer context and test can't pass
+@pytest.mark.xfail(
+    reason="Inconsistent image processor being used due to lack "
+    "of support for dynamic image token replacement")
 @pytest.mark.parametrize("model_and_config", model_and_vl_config)
 @pytest.mark.parametrize("dtype", [target_dtype])
 @pytest.mark.parametrize("max_tokens", [128])
--- a/vllm/model_executor/models/phi3v.py
+++ b/vllm/model_executor/models/phi3v.py
@@ -13,14 +13,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
+from typing import Dict, Iterable, List, Literal, Optional, Tuple, TypedDict
 
+import numpy as np
 import torch
 import torch.nn as nn
+from PIL import Image
 from transformers import CLIPVisionConfig, PretrainedConfig
 
 from vllm.attention import AttentionMetadata
-from vllm.config import CacheConfig, VisionLanguageConfig
+from vllm.config import CacheConfig, ModelConfig, VisionLanguageConfig
+from vllm.logger import init_logger
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
@@ -32,9 +35,11 @@ from vllm.model_executor.models.llama import LlamaModel
 from vllm.model_executor.models.vlm_base import VisionLanguageModelBase
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.image import get_dummy_image_data
+from vllm.multimodal.image import ImagePixelData, get_dummy_image_data
 from vllm.sequence import SamplerOutput
 
+logger = init_logger(__name__)
+
 _KEYS_TO_MODIFY_MAPPING = {
     "model.vision_embed_tokens": "vision_embed_tokens",
 }
@@ -268,7 +273,63 @@ class Phi3VImagePixelInputs(TypedDict):
     """Shape: (batch_size, 2)"""
 
 
-@MULTIMODAL_REGISTRY.register_image_pixel_input()
+# FIXME(Isotr0py): Remove these after dynamic num_img_tokens is supported
+# copied from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py
+def calc_padded_size(width, height, padding_unit=336):
+    target_height = int(np.ceil(height / padding_unit) * padding_unit)
+    top_padding = int((target_height - height) / 2)
+    bottom_padding = target_height - height - top_padding
+    padded_width = width
+    padded_height = height + top_padding + bottom_padding
+    return padded_width, padded_height
+
+
+# copied from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py
+def calc_hd_transform_size(width, height, hd_num=16):
+    transposed = False
+    if width < height:
+        width, height = height, width
+        transposed = True
+
+    ratio = width / height
+    scale = 1
+    while scale * np.ceil(scale / ratio) <= hd_num:
+        scale += 1
+    scale -= 1
+
+    new_width = int(scale * 336)
+    new_height = int(new_width / ratio)
+
+    padded_width, padded_height = calc_padded_size(new_width, new_height)
+
+    if transposed:
+        padded_width, padded_height = padded_height, padded_width
+
+    return padded_width, padded_height
+
+
+def _image_processor(
+    data: ImagePixelData,
+    model_config: ModelConfig,
+    vlm_config: VisionLanguageConfig,
+) -> Dict[str, torch.Tensor]:
+    image = data.image
+
+    if isinstance(image, Image.Image):
+        # Temporary patch before dynamic number of image tokens is supported
+        _, _, h, w = vlm_config.image_input_shape
+        if (w, h) != calc_hd_transform_size(image.width, image.height):
+            logger.warning(
+                "Dynamic image shape is currently not supported. "
+                "Resizing input image to (%d, %d).", w, h)
+
+        data.image = image.resize((w, h))
+
+    return MULTIMODAL_REGISTRY._get_plugin_for_data_type(ImagePixelData) \
+        ._default_input_processor(data, model_config, vlm_config)
+
+
+@MULTIMODAL_REGISTRY.register_image_pixel_input(_image_processor)
 @MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data)
 class Phi3VForCausalLM(VisionLanguageModelBase):
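As a sanity check on the helpers introduced above, here is a minimal standalone sketch; it copies calc_padded_size and calc_hd_transform_size verbatim from this diff so it runs on its own. Because the chosen scale depends only on the aspect ratio, every 4:3 landscape size maps to (1344, 1008), the (width, height) corresponding to the (1008, 1344) case in the test above:

import numpy as np


# Copied verbatim from the helpers added in this diff, so the check is standalone.
def calc_padded_size(width, height, padding_unit=336):
    target_height = int(np.ceil(height / padding_unit) * padding_unit)
    top_padding = int((target_height - height) / 2)
    bottom_padding = target_height - height - top_padding
    padded_width = width
    padded_height = height + top_padding + bottom_padding
    return padded_width, padded_height


def calc_hd_transform_size(width, height, hd_num=16):
    transposed = False
    if width < height:
        width, height = height, width
        transposed = True

    ratio = width / height
    scale = 1
    while scale * np.ceil(scale / ratio) <= hd_num:
        scale += 1
    scale -= 1

    new_width = int(scale * 336)
    new_height = int(new_width / ratio)
    padded_width, padded_height = calc_padded_size(new_width, new_height)
    if transposed:
        padded_width, padded_height = padded_height, padded_width
    return padded_width, padded_height


# Every 4:3 landscape size lands on the same (1344, 1008) grid.
for size in [(800, 600), (1344, 1008), (2688, 2016)]:
    print(size, "->", calc_hd_transform_size(*size))

This is why _image_processor can compare calc_hd_transform_size(image.width, image.height) directly against the (w, h) taken from vlm_config.image_input_shape: sizes whose transform already matches the static shape skip the warning, and everything else is resized to (w, h) before the default pixel-input processor runs.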