[Bugfix] Clean up and fix multi-modal processors (#13012)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

parent fde71262e0
commit 51f0b5f7f6
@@ -297,7 +297,7 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardware
 * ✅
 * ✅
 * ?
-* [✗](gh-issue:7968>)
+* [✗](gh-issue:7968)
 * ?
 * ✅
 *
@@ -26,6 +26,9 @@ from ...utils import check_logprobs_close
         "google/gemma-1.1-2b-it", # gemma
         marks=[pytest.mark.core_model, pytest.mark.cpu_model],
     ),
+    pytest.param(
+        "THUDM/chatglm3-6b", # ChatGLM (text-only)
+    ),
     pytest.param(
         "meta-llama/Llama-3.2-1B-Instruct", # llama
         marks=[pytest.mark.core_model, pytest.mark.cpu_model],
@@ -43,6 +46,9 @@ from ...utils import check_logprobs_close
         "microsoft/phi-2", # phi
         marks=[pytest.mark.core_model],
     ),
+    pytest.param(
+        "Qwen/Qwen-7B", # qwen (text-only)
+    ),
     pytest.param(
         "Qwen/Qwen2.5-0.5B-Instruct", # qwen2
         marks=[pytest.mark.core_model],
@@ -68,6 +74,10 @@ def test_models(
 ) -> None:
 
     with hf_runner(model, dtype=dtype) as hf_model:
+        if model.startswith("THUDM/chatglm3"):
+            hf_model.model.get_output_embeddings = lambda: \
+                hf_model.model.transformer.output_layer
+
         hf_outputs = hf_model.generate_greedy_logprobs_limit(
             example_prompts, max_tokens, num_logprobs)
 
@@ -89,7 +89,7 @@ def _test_processing_correctness(
     mm_data = {
         k:
        [(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]())
-         for _ in range(rng.randint(limit))]
+         for _ in range(rng.randint(limit + 1))]
        for k, limit in limit_mm_per_prompt.items()
    }

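The upper-bound change matters because NumPy's randint excludes its high end. A standalone sketch (not part of the diff), assuming `rng` is a numpy RandomState as the surrounding rng.rand()/rng.randint() calls suggest:

# randint(limit) draws from [0, limit), so the old code could never place
# the full `limit` items for a modality; randint(limit + 1) makes the
# per-prompt limit reachable.
import numpy as np

rng = np.random.RandomState(0)
limit = 3

old_draws = {rng.randint(limit) for _ in range(1000)}      # values in {0, 1, 2}
new_draws = {rng.randint(limit + 1) for _ in range(1000)}  # values in {0, 1, 2, 3}

assert limit not in old_draws
assert limit in new_draws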
@@ -17,10 +17,7 @@ def random_video(
     min_wh: int,
     max_wh: int,
 ):
-    # Temporary workaround for https://github.com/huggingface/transformers/issues/35412
     num_frames = rng.randint(min_frames, max_frames)
-    num_frames = (num_frames // 2) * 2
-
     w, h = rng.randint(min_wh, max_wh, size=(2, ))
     return rng.randint(0, 255, size=(num_frames, w, h, 3), dtype=np.uint8)
 
@@ -4,8 +4,8 @@
 # https://github.com/THUDM/CogAgent
 """Inference-only CogAgent model compatible with THUDM weights."""
 from argparse import Namespace
-from typing import (Iterable, List, Mapping, Optional, Sequence, Set, Tuple,
-                    TypedDict, Union)
+from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
+                    Union)
 
 import torch
 from torch import nn
@@ -19,7 +19,6 @@ from transformers.tokenization_utils_base import TextInput
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
-from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -37,12 +36,10 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
-from vllm.multimodal.parse import ImageSize, MultiModalDataItems
+from vllm.multimodal.parse import MultiModalDataItems
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, BatchFeature,
-                                        BoundPromptReplacement,
                                         MultiModalFieldConfig,
-                                        PlaceholderFeaturesInfo,
                                         PromptReplacement)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
@@ -53,39 +50,6 @@ from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix, merge_multimodal_embeddings)
 
-logger = init_logger(__name__)
-
-IMAGE_TOKEN_ID = 151329
-
-
-def build_normalization_transform(image_size: int) -> transforms.Compose:
-    """
-    Build a normalization transform which can be applied to one or
-    more input images from which we want to extract visual features.
-
-    Args:
-        image_size: size of the image to be processed for visual embeddings.
-
-    Returns:
-        Callable transform for normalizing and resizing one RGB image.
-    """
-
-    return transforms.Compose([
-        transforms.Resize(
-            (image_size, image_size),
-            interpolation=InterpolationMode.BICUBIC,
-        ),
-        transforms.ToTensor(),
-        transforms.Normalize(
-            (0.48145466, 0.4578275, 0.40821073),
-            (0.26862954, 0.26130258, 0.27577711),
-        ),
-    ])
-
-
-def calculate_image_placeholder(vision_config):
-    return (vision_config["image_size"] // vision_config["patch_size"] // 2)**2
-
-
 class GLMImagePixelInputs(TypedDict):
     pixel_values: torch.Tensor
@@ -109,9 +73,20 @@ class GLM4VProcessor:
         self.config = config
         self.tokenizer = tokenizer
 
-        if hasattr(self.config, "vision_config"):
-            self.image_transform = build_normalization_transform(
-                config.vision_config["image_size"])
+        if vision_config := getattr(config, "vision_config", None):
+            image_size = vision_config["image_size"]
+
+            self.image_transform = transforms.Compose([
+                transforms.Resize(
+                    (image_size, image_size),
+                    interpolation=InterpolationMode.BICUBIC,
+                ),
+                transforms.ToTensor(),
+                transforms.Normalize(
+                    mean=(0.48145466, 0.4578275, 0.40821073),
+                    std=(0.26862954, 0.26130258, 0.27577711),
+                ),
+            ])
         else:
             self.image_transform = None
 
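For context, the Compose pipeline inlined above is standard CLIP-style preprocessing. A standalone sketch of what it produces, using a hypothetical 1120x1120 target size (the processor reads the real value from vision_config["image_size"]):

# Standalone sketch of the CLIP-style preprocessing inlined above;
# the 1120x1120 size is a placeholder, not taken from the diff.
import numpy as np
from PIL import Image
from torchvision import transforms
from torchvision.transforms import InterpolationMode

image_transform = transforms.Compose([
    transforms.Resize((1120, 1120), interpolation=InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.48145466, 0.4578275, 0.40821073),
        std=(0.26862954, 0.26130258, 0.27577711),
    ),
])

image = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))
pixel_values = image_transform(image)
print(pixel_values.shape)  # torch.Size([3, 1120, 1120])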
@@ -150,9 +125,19 @@ class GLM4VProcessor:
 
 class GLM4VProcessingInfo(BaseProcessingInfo):
 
-    def __init__(self, ctx):
-        super().__init__(ctx)
-        self._pre_calculate()
+    def get_tokenizer(self):
+        tokenizer = self.ctx.tokenizer
+        assert isinstance(tokenizer, PreTrainedTokenizer)
+        return tokenizer
+
+    def get_hf_config(self):
+        return self.ctx.get_hf_config(ChatGLMConfig)
+
+    def get_hf_processor(self) -> GLM4VProcessor:
+        return GLM4VProcessor(
+            self.get_hf_config(),
+            self.get_tokenizer(),
+        )
 
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": 1}
@@ -162,27 +147,21 @@ class GLM4VProcessingInfo(BaseProcessingInfo):
         seq_len: int,
         mm_counts: Mapping[str, int],
     ) -> Mapping[str, int]:
-        return {"image": self.image_token_num + 2}
-
-    def _pre_calculate(self):
-        hf_config = self.get_hf_config()
-        vision_config = hf_config.vision_config
-        self.image_token_num = calculate_image_placeholder(vision_config)
-        self.image_size = vision_config["image_size"]
+        return {"image": self.get_num_image_feature_tokens()}
 
     def get_num_image_tokens(self) -> int:
-        return self.image_token_num + 2
-
-    def get_image_size(self) -> ImageSize:
-        return ImageSize(height=self.image_size, width=self.image_size)
-
-    def get_hf_processor(self) -> GLM4VProcessor:
-        return GLM4VProcessor(
-            self.get_hf_config(),
-            self.get_tokenizer(),
-        )
+        hf_config = self.get_hf_config()
+        if not (vision_config := getattr(hf_config, "vision_config", None)):
+            return 0
+
+        image_size = vision_config["image_size"]
+        patch_size = vision_config["patch_size"]
+        grid_length = image_size // patch_size // 2
+        return grid_length * grid_length
+
+    def get_num_image_feature_tokens(self) -> int:
+        # EVA2CLIPModel has embeddings for boi and eoi tokens as well
+        return self.get_num_image_tokens() + 2
 
 
 class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]):
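As a sanity check on the arithmetic above, a worked example with illustrative numbers (image_size=1120 and patch_size=14 are the values commonly published for GLM-4V; the code always reads them from hf_config.vision_config):

# Worked example of the placeholder count; the inputs are illustrative,
# not read from any real config here.
def num_image_tokens(image_size: int, patch_size: int) -> int:
    grid_length = image_size // patch_size // 2
    return grid_length * grid_length

tokens = num_image_tokens(image_size=1120, patch_size=14)
print(tokens)      # 1600 image pad tokens per image
print(tokens + 2)  # 1602 once the boi/eoi embeddings are counted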
@@ -192,8 +171,12 @@ class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]):
         seq_len: int,
         mm_counts: Mapping[str, int],
     ) -> ProcessorInputs:
+        hf_config = self.info.get_hf_config()
+        if not (vision_config := getattr(hf_config, "vision_config", None)):
+            return ProcessorInputs(prompt_text="", mm_data={})
+
+        target_width = target_height = vision_config["image_size"]
         num_images = mm_counts.get("image", 0)
-        target_width, target_height = self.info.get_image_size()
 
         mm_data = {
             "image":
@@ -201,9 +184,11 @@ class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]):
                                    height=target_height,
                                    num_images=num_images)
         }
-        text = "<|begin_of_image|><|endoftext|><|end_of_image|>"
+
+        base_text = "<|begin_of_image|><|endoftext|><|end_of_image|>"
 
         return ProcessorInputs(
-            prompt_text=text,
+            prompt_text=base_text * num_images,
             mm_data=mm_data,
         )
 
@@ -223,47 +208,28 @@ class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]):
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
+        hf_config = self.info.get_hf_config()
+        if not hasattr(hf_config, "vision_config"):
+            return []
+
+        boi_token_id = hf_config.boi_token_id
+        image_token_id = hf_config.pad_token_id
+        eoi_token_id = hf_config.eoi_token_id
+
         def get_replacement(item_idx: int):
-            image_tokens = self.info.image_token_num
-            return [IMAGE_TOKEN_ID] * image_tokens
+            num_image_tokens = self.info.get_num_image_tokens()
+            image_tokens = [image_token_id] * num_image_tokens
+
+            return [boi_token_id] + image_tokens + [eoi_token_id]
 
         return [
             PromptReplacement(
                 modality="image",
-                target=[IMAGE_TOKEN_ID],
+                target=[boi_token_id, image_token_id, eoi_token_id],
                 replacement=get_replacement,
             ),
         ]
 
-    def _apply_prompt_replacements(
-        self,
-        token_ids: list[int],
-        mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
-        mm_item_counts: Mapping[str, int],
-    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
-        token_ids, text, placeholders = super()._apply_prompt_replacements(
-            token_ids=token_ids,
-            mm_prompt_repls=mm_prompt_repls,
-            mm_item_counts=mm_item_counts,
-        )
-        hf_config = self.info.get_hf_config()
-        boi_token_id = hf_config.boi_token_id
-        eoi_token_id = hf_config.eoi_token_id
-        placeholders = {
-            modality: [
-                PlaceholderFeaturesInfo(
-                    modality=p.modality,
-                    item_idx=p.item_idx,
-                    start_idx=p.start_idx - 1,
-                    tokens=[boi_token_id] + p.tokens + [eoi_token_id],
-                ) for p in ps
-            ]
-            for modality, ps in placeholders.items()
-        }
-
-        return token_ids, text, placeholders
-
 
 class GLMAttention(nn.Module):
 
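To make the replacement rule above concrete, here is a toy re-implementation (pure Python with made-up token ids, not vLLM's actual matcher) of how the [boi, pad, eoi] target expands into boi + N image tokens + eoi:

# Toy illustration of the prompt-replacement rule; token ids are hypothetical.
def expand_placeholder(token_ids, boi, pad, eoi, num_image_tokens):
    target = [boi, pad, eoi]
    replacement = [boi] + [pad] * num_image_tokens + [eoi]
    out, i = [], 0
    while i < len(token_ids):
        if token_ids[i:i + len(target)] == target:
            out.extend(replacement)
            i += len(target)
        else:
            out.append(token_ids[i])
            i += 1
    return out

prompt = [1, 900, 901, 902, 2]  # hypothetical ids: boi=900, pad=901, eoi=902
print(expand_placeholder(prompt, 900, 901, 902, 4))
# -> [1, 900, 901, 901, 901, 901, 902, 2]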
@@ -618,7 +584,7 @@ class ChatGLMModel(nn.Module):
             multimodal_embeddings=multimodal_embeddings,
             placeholder_token_id=[
                 self.config.boi_token_id,
-                IMAGE_TOKEN_ID,
+                self.config.pad_token_id,
                 self.config.eoi_token_id,
             ],
         )
@@ -63,18 +63,6 @@ from .utils import (flatten_bn, is_pp_missing_parameter,
 
 logger = init_logger(__name__)
 
-# NOTE: Qwen models have a few other special tags, e.g., ref, bbox, quad;
-# for the time being, these tags are not considered as special at encoding
-# time. This may change as VLLMs multimodal API changes in the future.
-IMG_START = "<img>"
-IMG_END = "</img>"
-IMG_PAD = "<imgpad>"
-# Image context is fixed at 256 for all images
-MAX_QWEN_IMG_TOKENS = 256
-# Image normalization params
-CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
-CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
-
 
 class QwenImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
@@ -622,25 +610,6 @@ class QWenModel(nn.Module):
         return hidden_states
 
 
-def build_normalization_transform(image_size: int) -> transforms.Compose:
-    """
-    Build a normalization transform which can be applied to one or
-    more input images from which we want to extract visual features.
-
-    Args:
-        image_size: size of the image to be processed for visual embeddings.
-
-    Returns:
-        Callable transform for normalizing and resizing one RGB image.
-    """
-    return transforms.Compose([
-        transforms.Resize((image_size, image_size),
-                          interpolation=InterpolationMode.BICUBIC),
-        transforms.ToTensor(),
-        transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
-    ])
-
-
 @lru_cache(maxsize=1)
 def _get_tokenizer_without_image_pad(
         tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer:
@@ -716,16 +685,34 @@ class QWenVLProcessor:
         self.config = config
         self.tokenizer = tokenizer
 
-        if hasattr(self.config, "visual"):
-            self.image_transform = build_normalization_transform(
-                config.visual["image_size"])
+        if vision_config := getattr(self.config, "visual", None):
+            image_size = vision_config["image_size"]
+
+            self.image_transform = transforms.Compose([
+                transforms.Resize(
+                    (image_size, image_size),
+                    interpolation=InterpolationMode.BICUBIC,
+                ),
+                transforms.ToTensor(),
+                transforms.Normalize(
+                    mean=(0.48145466, 0.4578275, 0.40821073),
+                    std=(0.26862954, 0.26130258, 0.27577711),
+                ),
+            ])
         else:
             self.image_transform = None
 
-        special_tokens: dict[str,
-                             int] = tokenizer.special_tokens  # type: ignore
-        self.img_start_id = special_tokens[IMG_START]
-        self.img_end_id = special_tokens[IMG_END]
+    @property
+    def image_start_tag(self) -> str:
+        return self.tokenizer.image_start_tag  # type: ignore
+
+    @property
+    def image_end_tag(self) -> str:
+        return self.tokenizer.image_end_tag  # type: ignore
+
+    @property
+    def image_pad_tag(self) -> str:
+        return self.tokenizer.image_pad_tag  # type: ignore
 
     def __call__(
         self,
@@ -787,7 +774,14 @@ class QWenVLProcessingInfo(BaseProcessingInfo):
         return {"image": self.get_num_image_tokens()}
 
     def get_num_image_tokens(self) -> int:
-        return MAX_QWEN_IMG_TOKENS
+        hf_config = self.get_hf_config()
+        if not (vision_config := getattr(hf_config, "visual", None)):
+            return 0
+
+        image_size = vision_config["image_size"]
+        patch_size = vision_config["patch_size"]
+        grid_length = image_size // patch_size // 2
+        return grid_length * grid_length
 
 
 class QWenVLDummyInputsBuilder(BaseDummyInputsBuilder[QWenVLProcessingInfo]):
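The formula reproduces the constant it replaces: with the visual settings usually shipped with Qwen-VL (image_size=448, patch_size=14, shown here only for illustration; the code reads them from hf_config.visual), it evaluates to the removed MAX_QWEN_IMG_TOKENS value:

# Quick check that the new formula matches the old MAX_QWEN_IMG_TOKENS = 256
# constant for the usual Qwen-VL visual settings (illustrative values).
image_size, patch_size = 448, 14
grid_length = image_size // patch_size // 2  # 448 // 14 // 2 == 16
print(grid_length * grid_length)             # 256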
@@ -798,10 +792,12 @@ class QWenVLDummyInputsBuilder(BaseDummyInputsBuilder[QWenVLProcessingInfo]):
         mm_counts: Mapping[str, int],
     ) -> ProcessorInputs:
         hf_config = self.info.get_hf_config()
-        if not hasattr(hf_config, "visual"):
+        if not (vision_config := getattr(hf_config, "visual", None)):
             return ProcessorInputs(prompt_text="", mm_data={})
 
-        vision_config = hf_config.visual
+        processor = self.info.get_hf_processor()
+        img_start = processor.image_start_tag
+        img_end = processor.image_end_tag
+
         target_width = target_height = vision_config["image_size"]
         num_images = mm_counts.get("image", 0)
@@ -814,7 +810,7 @@ class QWenVLDummyInputsBuilder(BaseDummyInputsBuilder[QWenVLProcessingInfo]):
         }
 
         return ProcessorInputs(
-            prompt_text="".join(f"Picture {i}: {IMG_START}{IMG_END}\n"
+            prompt_text="".join(f"Picture {i}: {img_start}{img_end}\n"
                                 for i in range(1, num_images + 1)),
             mm_data=mm_data,
         )
@@ -869,13 +865,18 @@ class QWenVLMultiModalProcessor(BaseMultiModalProcessor[QWenVLProcessingInfo]):
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
+        hf_config = self.info.get_hf_config()
+        if not hasattr(hf_config, "visual"):
+            return []
+
         tokenizer = self.info.get_tokenizer()
         special_tokens: dict[str,
                              int] = tokenizer.special_tokens  # type: ignore
 
-        img_start_id = special_tokens[IMG_START]
-        img_end_id = special_tokens[IMG_END]
-        img_pad_id = special_tokens[IMG_PAD]
+        processor = self.info.get_hf_processor()
+        img_start_id = special_tokens[processor.image_start_tag]
+        img_end_id = special_tokens[processor.image_end_tag]
+        img_pad_id = special_tokens[processor.image_pad_tag]
 
         num_image_tokens = self.info.get_num_image_tokens()
         image_tokens = [img_pad_id] * num_image_tokens
@@ -885,14 +885,10 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
         max_image_tokens = self.get_max_image_tokens() * max_images
         max_total_frames = self._get_max_video_frames(seq_len -
                                                       max_image_tokens)
-        num_frames = min(max(max_total_frames // max(max_videos, 1), 1),
-                         _MAX_FRAMES_PER_VIDEO)
-
-        # Temporary workaround for https://github.com/huggingface/transformers/issues/35412
-        if num_frames > 1 and num_frames % 2 == 1:
-            num_frames += 1
-
-        return num_frames
+        max_frames_per_video = min(max_total_frames // max(max_videos, 1),
+                                   _MAX_FRAMES_PER_VIDEO)
+
+        return max(max_frames_per_video, 1)
 
     def get_max_video_tokens(self, seq_len: int) -> int:
         target_width, target_height = self.get_image_size_with_most_features()
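A standalone sketch of the new per-video frame budgeting, with made-up numbers (the real _MAX_FRAMES_PER_VIDEO cap is defined in the Qwen2-VL module): the count is capped first and then clamped to at least one frame, and the even-frame workaround is gone.

# Sketch only; the cap value below is illustrative, not vLLM's.
_MAX_FRAMES_PER_VIDEO = 16

def max_frames_per_video(max_total_frames: int, max_videos: int) -> int:
    per_video = min(max_total_frames // max(max_videos, 1),
                    _MAX_FRAMES_PER_VIDEO)
    # Clamp after capping, so at least one frame is always budgeted.
    return max(per_video, 1)

print(max_frames_per_video(100, 4))  # 16 -> capped by _MAX_FRAMES_PER_VIDEO
print(max_frames_per_video(3, 8))    # 1  -> never returns zero frames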