mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-20 03:27:03 +08:00
388 lines
14 KiB
Python
388 lines
14 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import math
|
|
from collections.abc import Iterable, Mapping, Sequence
|
|
from typing import Annotated, Literal, Optional, Union
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from transformers import BatchFeature, NougatProcessor
|
|
|
|
from vllm.config import VllmConfig
|
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|
from vllm.model_executor.models.bart import BartParallelLMHead, MBartDecoder
|
|
from vllm.model_executor.models.interfaces import (MultiModalEmbeddings,
|
|
SupportsMultiModal,
|
|
SupportsV0Only)
|
|
from vllm.model_executor.models.swin import SwinModel
|
|
from vllm.model_executor.models.utils import (AutoWeightsLoader,
|
|
_flatten_embeddings, flatten_bn)
|
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
|
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
|
MultiModalKwargsItems)
|
|
from vllm.multimodal.parse import MultiModalDataItems
|
|
from vllm.multimodal.processing import (BaseProcessingInfo,
|
|
EncDecMultiModalProcessor,
|
|
PromptIndexTargets, PromptInsertion,
|
|
PromptUpdate)
|
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
|
|
|
|
|
class MBartDecoderWrapper(nn.Module):
    """Wrap an :class:`MBartDecoder` as a ``decoder`` submodule.

    All computation is delegated; this class only provides the extra
    module level so weight names take the ``<prefix>.decoder.*`` form.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.decoder = MBartDecoder(
            vllm_config.model_config.hf_config,
            vllm_config.cache_config,
            quant_config=vllm_config.quant_config,
            prefix=f"{prefix}.decoder",
        )

    def forward(self, *args, **kwargs):
        # Pure pass-through to the wrapped decoder.
        return self.decoder(*args, **kwargs)
|
|
|
|
|
|
class DonutLanguageForConditionalGeneration(nn.Module, SupportsV0Only):
    """Donut's text side: an MBart decoder plus a parallel LM head."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        self.config = config

        self.model = MBartDecoderWrapper(vllm_config=vllm_config,
                                         prefix=f"{prefix}.model")

        # Scale the tied embedding by sqrt(d_model) only when the config
        # asks for it, mirroring the ``scale_embedding`` flag.
        if config.scale_embedding:
            embed_scale = math.sqrt(config.d_model)
        else:
            embed_scale = 1.0

        self.vocab_size = config.vocab_size
        self.lm_head = BartParallelLMHead(self.vocab_size,
                                          config.d_model,
                                          embed_scale=embed_scale)

        self.logits_processor = LogitsProcessor(self.vocab_size,
                                                config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        r"""Run the decoder, cross-attending to ``inputs_embeds``.

        Args:
            input_ids: *decoder* input token ids.
            positions: *decoder* position indices.
            inputs_embeds: encoder hidden states consumed via
                cross-attention.

        Returns:
            Decoder hidden states.
        """
        return self.model(decoder_input_ids=input_ids,
                          decoder_positions=positions,
                          encoder_hidden_states=inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        # Project hidden states onto the vocabulary via the LM head.
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, fusing q/k/v shards into ``qkv_proj``."""
        # (fused param name, checkpoint shard name, shard id)
        fused_mapping = (
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        )

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            for fused_name, shard_name, shard_id in fused_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, fused_name)
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No vLLM parameter corresponds to this buffer; skip it.
                if "final_logits_bias" in name:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
|
|
|
|
|
|
class DonutImagePixelInputs(TensorSchema):
    """Schema-validated pixel-value input for the Donut image encoder.

    Dimensions:
        - b: Batch size
        - c: Number of channels (3)
        - h: Height
        - w: Width
    """
    # Discriminator tag; this schema only covers raw pixel inputs
    # (precomputed image embeddings are a separate, unsupported path).
    type: Literal["pixel_values"]
    # Image tensor of shape (b, 3, h, w); h/w are bound by the caller via
    # ``resolve_bindings`` from the encoder's configured image size.
    data: Annotated[torch.Tensor, TensorShape("b", 3, "h", "w")]
|
|
|
|
|
|
class DonutProcessingInfo(BaseProcessingInfo):
    """Static processing metadata for the Donut processor pipeline."""

    def get_hf_config(self):
        """Return the HF model config from the processing context."""
        return self.ctx.get_hf_config()

    def get_hf_processor(self):
        """Return the HF processor from the processing context."""
        return self.ctx.get_hf_processor()

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # At most one image per prompt.
        return {"image": 1}

    def get_num_image_tokens(self) -> int:
        # A single placeholder token represents the image in the prompt.
        return 1
|
|
|
|
|
|
class DonutDummyInputsBuilder(BaseDummyInputsBuilder[DonutProcessingInfo]):
    """Builds dummy multimodal inputs used for profiling/warmup."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        # No text accompanies the dummy image.
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Create dummy images sized to the encoder's configured input."""
        n_images = mm_counts.get("image", 0)

        # NOTE(review): ``image_size`` is unpacked as (width, height) here
        # but as (height, width) in
        # DonutForConditionalGeneration._parse_and_validate_image_input —
        # confirm the intended ordering against the HF config.
        target_width, target_height = self.info.get_hf_config(
        ).encoder.image_size

        dummy_images = self._get_dummy_images(width=target_width,
                                              height=target_height,
                                              num_images=n_images)
        return {"image": dummy_images}
|
|
|
|
|
|
class DonutMultiModalProcessor(EncDecMultiModalProcessor[DonutProcessingInfo]):
    """Encoder/decoder multimodal processor for Donut-style models."""

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        # Prompt updates are applied by vLLM, never by the HF processor.
        return False

    def create_encoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
        # The encoder prompt is used unchanged.
        return prompt

    def create_decoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
        # The decoder prompt is used unchanged.
        return prompt

    @property
    def pad_dummy_encoder_prompt(self) -> bool:
        # Dummy encoder prompts are padded out to full length.
        return True

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        hf_processor = self.info.get_hf_processor()

        if not mm_data:
            # Text-only request: tokenize directly, skipping image
            # preprocessing entirely.
            return hf_processor.tokenizer(prompt,
                                          add_special_tokens=False,
                                          return_tensors="pt")

        processed_outputs = super()._call_hf_processor(
            prompt, mm_data, mm_kwargs, tok_kwargs)
        # NougatProcessor places token ids under "labels"; mirror them
        # into "input_ids" so downstream handling is uniform.
        if isinstance(hf_processor, NougatProcessor):
            processed_outputs["input_ids"] = processed_outputs["labels"]
        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Pixel values are batched per image item.
        return {"pixel_values": MultiModalFieldConfig.batched("image")}

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        tokenizer = self.info.get_hf_processor().tokenizer
        # The image is represented by pad-token placeholders inserted at
        # the very start of the prompt.
        placeholder = ([tokenizer.pad_token_id] *
                       self.info.get_num_image_tokens())

        return [
            PromptInsertion(
                modality="image",
                target=PromptIndexTargets.start(),
                insertion=placeholder,
            )
        ]
|
|
|
|
|
|
@MULTIMODAL_REGISTRY.register_processor(DonutMultiModalProcessor,
                                        info=DonutProcessingInfo,
                                        dummy_inputs=DonutDummyInputsBuilder)
class DonutForConditionalGeneration(nn.Module, SupportsMultiModal,
                                    SupportsV0Only):
    """Donut: a Swin image encoder feeding an MBart-style text decoder."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config

        self.config = config
        self.vision_config = config.encoder
        self.processor_config = (
            vllm_config.model_config.hf_image_processor_config)
        self.encoder = SwinModel(config=config.encoder)

        # The decoder gets its own hf_config (the nested decoder config).
        self.decoder = DonutLanguageForConditionalGeneration(
            vllm_config=vllm_config.with_hf_config(config.decoder),
            prefix=f"{prefix}.decoder",
        )
        self.pad_token_id = config.pad_token_id

    def _parse_and_validate_image_input(self, **kwargs: object):
        """Pop image kwargs and validate them into a typed input object.

        Returns ``None`` when no image data is present; raises on
        conflicting or unsupported inputs.
        """
        pixel_values: Optional[Union[list[list[torch.Tensor]],
                                     list[torch.Tensor],
                                     torch.Tensor]] = kwargs.pop(
                                         "pixel_values", None)
        image_embeds: Optional[Union[list[list[torch.Tensor]],
                                     list[torch.Tensor],
                                     torch.Tensor]] = kwargs.pop(
                                         "image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None
        if pixel_values is not None and image_embeds is not None:
            raise ValueError(
                "Both pixel values and image embeds are provided.")

        if image_embeds is not None:
            # Precomputed image embeddings are not supported here.
            raise NotImplementedError

        # Only pixel_values remains at this point.
        h, w = self.config.encoder.image_size
        return DonutImagePixelInputs(
            type="pixel_values",
            data=flatten_bn(pixel_values, concat=True),
            resolve_bindings={
                "h": h,
                "w": w,
            },
        )

    def _process_image_input(
            self, image_input: DonutImagePixelInputs) -> torch.Tensor:
        assert image_input["type"] == "pixel_values"
        # Cast to the encoder's parameter dtype before running Swin.
        target_dtype = next(self.encoder.parameters()).dtype
        return self.encoder(image_input["data"].to(target_dtype))

    def get_language_model(self) -> torch.nn.Module:
        return self.decoder

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        return self._process_image_input(image_input)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings,
    ) -> torch.Tensor:
        # The encoder prompt carries no real text; its "embeddings" are
        # simply the flattened vision embeddings.
        return _flatten_embeddings(multimodal_embeddings)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        *,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        r"""Encode the image (if present) and run the text decoder.

        Args:
            input_ids: *decoder* input token ids.
            positions: *decoder* position indices.
            encoder_input_ids: *encoder* input token ids.
            encoder_positions: *encoder* position indices.

        Returns:
            Decoder hidden states.
        """
        inputs_embeds = None
        # A non-empty encoder prompt signals an accompanying image:
        # run the vision encoder and hand its output to cross-attention.
        if encoder_input_ids.numel() > 0:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(encoder_input_ids,
                                                      vision_embeddings)

        return self.decoder(input_ids,
                            positions,
                            inputs_embeds=inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        # Logit computation is owned by the language decoder.
        return self.decoder.compute_logits(hidden_states, sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        # Recursively dispatch weights to encoder/decoder submodules.
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)
|