[Bugfix] Clean up MiniCPM-V (#6939)

Co-authored-by: hezhihui <hzh7269@modelbest.cn>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Alphi · 2024-07-31 22:39:19 +08:00 · committed by GitHub
parent 6512937de1
commit 2f4e108f75
6 changed files with 975 additions and 94 deletions


@ -222,9 +222,13 @@ Vision Language Models
-
* - :code:`MiniCPM-V`
- MiniCPM-V
- :code:`openbmb/MiniCPM-V-2`, :code:`openbmb/MiniCPM-Llama3-V-2_5`, etc.
- :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, etc.
-
.. note::
For :code:`openbmb/MiniCPM-V-2`, the official repository does not work with vLLM yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630
----
If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
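For example, here is a minimal offline-inference sketch using the fork. This is not part of the commit; the prompt template and multi-modal input format are assumptions that may differ across vLLM versions.

from PIL import Image
from vllm import LLM, SamplingParams

# Load MiniCPM-V-2 from the fork; trust_remote_code is required for its custom code.
llm = LLM(model="HwwwH/MiniCPM-V-2", trust_remote_code=True)

# Hypothetical local image; the "(<image>./</image>)" placeholder follows the
# MiniCPM-V prompt convention assumed here.
image = Image.open("example.jpg").convert("RGB")
outputs = llm.generate(
    {
        "prompt": "(<image>./</image>)\nWhat is shown in this image?",
        "multi_modal_data": {"image": image},
    },
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)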


@ -418,11 +418,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
input_embeds: Optional[torch.Tensor] = None
) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
input_embeds)
attn_metadata, intermediate_tensors)
return model_output
def compute_logits(self, hidden_states: torch.Tensor,


@ -370,6 +370,7 @@ class MiniCPMModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if inputs_embeds is not None:
@ -463,11 +464,10 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
input_embeds: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, input_embeds)
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,


@ -20,32 +20,34 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiniCPM-V-2 model compatible with HuggingFace weights."""
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
import math
import re
from functools import partial
from typing import Iterable, List, Optional, Tuple
from typing import Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
import torch.types
from PIL import Image
from torch import nn
from torch.nn.init import trunc_normal_
from transformers.configuration_utils import PretrainedConfig
from transformers.models.idefics2.modeling_idefics2 import (
Idefics2VisionTransformer)
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsVision
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.models.minicpm import MiniCPMForCausalLM
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.models.minicpm import MiniCPMModel
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import (cached_get_image_processor,
@ -53,12 +55,12 @@ from vllm.multimodal.image import (cached_get_image_processor,
from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
_KEYS_TO_MODIFY_MAPPING = {
"language_model.lm_head": "lm_head",
"language_model.model": "language_model",
"llm.lm_head": "lm_head",
"llm.model": "llm",
}
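For illustration only (not part of the diff), this mapping is applied later in load_weights to rename checkpoint parameters; a self-contained sketch with a hypothetical helper, repeating the mapping so it runs standalone:

# Hypothetical helper mirroring the renaming loop in load_weights.
_KEYS_TO_MODIFY_MAPPING = {
    "llm.lm_head": "lm_head",
    "llm.model": "llm",
}

def _remap_name(name: str) -> str:
    for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
        if key_to_modify in name:
            name = name.replace(key_to_modify, new_key)
    return name

assert _remap_name("llm.model.embed_tokens.weight") == "llm.embed_tokens.weight"
assert _remap_name("llm.lm_head.weight") == "lm_head.weight"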
def get_abs_pos(abs_pos, tgt_size):
def get_abs_pos(abs_pos: torch.Tensor, tgt_size: torch.Tensor):
# abs_pos: L, C
# tgt_size: (H, W)
# return: M, C
@ -75,10 +77,10 @@ def get_abs_pos(abs_pos, tgt_size):
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_2d_sincos_pos_embed(embed_dim,
grid_size,
cls_token=False,
version=2.0):
def get_2d_sincos_pos_embed(embed_dim: int,
grid_size: Union[int, Tuple[int, int]],
cls_token: bool = False,
version: Tuple[int, int] = (2, 0)):
"""
grid_size: int of the grid height and width
return:
@ -95,7 +97,7 @@ def get_2d_sincos_pos_embed(embed_dim,
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
if version == 2.0:
if version == (2, 0):
grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
if cls_token:
@ -106,7 +108,9 @@ def get_2d_sincos_pos_embed(embed_dim,
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version=2.0):
def get_2d_sincos_pos_embed_from_grid(embed_dim: int,
grid: Union[int, Tuple[int, int]],
version: Tuple[int, int] = (2, 0)):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
@ -115,14 +119,16 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version=2.0):
emb_w = get_1d_sincos_pos_embed_from_grid(
embed_dim // 2, grid[1], version) # (H*W, D/2) or (H, W, D/2)
if version == 2.0:
if version == (2, 0):
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
else:
emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, version=2.0):
def get_1d_sincos_pos_embed_from_grid(embed_dim: int,
pos: int,
version: Tuple[int, int] = (2, 0)):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,) / (H, W)
@ -133,7 +139,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, version=2.0):
omega /= embed_dim / 2.
omega = 1. / 10000**omega # (D/2,)
if version == 2.0:
if version == (2, 0):
pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
@ -158,19 +164,19 @@ class Resampler(nn.Module):
default_norm_layer = partial(nn.LayerNorm, eps=1e-6)
def __init__(self,
num_queries,
grid_size,
embed_dim,
num_heads,
kv_dim=None,
norm_layer=default_norm_layer,
adaptive=False,
max_size=(70, 70),
version=2.0):
num_queries: int,
grid_size: int,
embed_dim: int,
num_heads: int,
kv_dim: Optional[int] = None,
norm_layer: nn.Module = default_norm_layer,
adaptive: bool = False,
max_size: Tuple[int, int] = (70, 70),
version: Tuple[int, int] = (2, 0)):
super().__init__()
self.version = version
if self.version == 2.0:
if self.version == (2, 0):
self.num_queries = grid_size**2
else:
self.num_queries = num_queries
@ -195,7 +201,7 @@ class Resampler(nn.Module):
self.proj = nn.Parameter(
(embed_dim**-0.5) * torch.randn(embed_dim, embed_dim))
if self.version == 2.0:
if self.version == (2, 0):
self.pos_embed = nn.Parameter(
torch.from_numpy(
get_2d_sincos_pos_embed(
@ -206,14 +212,17 @@ class Resampler(nn.Module):
self.apply(self._init_weights)
def _set_2d_pos_cache(self, max_size, device='cpu'):
def _set_2d_pos_cache(self,
max_size: Tuple[int, int],
device: torch.types.Device = 'cpu'):
pos_embed = torch.from_numpy(
get_2d_sincos_pos_embed(self.embed_dim,
max_size,
version=self.version)).float().to(device)
self.register_buffer("pos_embed", pos_embed, persistent=False)
def _adjust_pos_cache(self, tgt_sizes, device):
def _adjust_pos_cache(self, tgt_sizes: torch.Tensor,
device: torch.types.Device):
max_h = torch.max(tgt_sizes[:, 0])
max_w = torch.max(tgt_sizes[:, 1])
if max_h > self.max_size[0] or max_w > self.max_size[1]:
@ -223,7 +232,7 @@ class Resampler(nn.Module):
]
self._set_2d_pos_cache(self.max_size, device)
def _init_weights(self, m):
def _init_weights(self, m: nn.Module):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
@ -232,7 +241,9 @@ class Resampler(nn.Module):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward_2_5(self, x, tgt_sizes=None):
def forward_2_5(self,
x: torch.Tensor,
tgt_sizes: Optional[torch.Tensor] = None):
assert x.shape[0] == tgt_sizes.shape[0]
bs = x.shape[0]
@ -278,7 +289,10 @@ class Resampler(nn.Module):
x = x @ self.proj
return x
def forward_2(self, x, tgt_sizes=None, attn_mask=None):
def forward_2(self,
x: torch.Tensor,
tgt_sizes: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None):
if self.adaptive:
pos_embed = torch.Tensor(
get_2d_sincos_pos_embed(self.embed_dim,
@ -302,8 +316,11 @@ class Resampler(nn.Module):
x = x @ self.proj
return x
def forward(self, x, tgt_sizes=None, attn_mask=None):
if self.version == 2.0:
def forward(self,
x: torch.Tensor,
tgt_sizes: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None):
if self.version == (2, 0):
return self.forward_2(x, tgt_sizes=tgt_sizes, attn_mask=attn_mask)
else:
return self.forward_2_5(x, tgt_sizes=tgt_sizes)
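As an aside (not part of the diff), the switch from float versions such as 2.0 to tuples such as (2, 0) presumably keeps comparisons well defined once version strings stop being single decimals; a quick sketch of the parsing that MiniCPMV.__init__ performs below:

# Hypothetical helper mirroring the version parsing in MiniCPMV.__init__.
def parse_version(version_str: str) -> tuple:
    return tuple(int(x) for x in str(version_str).split("."))

assert parse_version("2.0") == (2, 0)
assert parse_version("2.5") == (2, 5)
assert (2, 5) < (2, 6)  # tuples compare lexicographically; float("2.10") would misorder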
@ -322,7 +339,7 @@ def dummy_seq_data_for_minicpmv(seq_len: int):
return SequenceData(token_ids)
def dummy_image_for_minicpmv(hf_config):
def dummy_image_for_minicpmv(hf_config: PretrainedConfig):
width = height = hf_config.image_size
image = Image.new("RGB", (width, height), color=0)
return {"image": image}
@ -381,7 +398,7 @@ class MiniCPMV(nn.Module, SupportsVision):
def __init__(
self,
config,
config: PretrainedConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
@ -390,30 +407,48 @@ class MiniCPMV(nn.Module, SupportsVision):
self.config = config
self.multimodal_config = multimodal_config
self.version = float(self.config.version)
if not hasattr(self.config, "version"):
if self.config.hidden_size == 2304 and self.config.query_num == 64:
self.version = (2, 0)
else:
self.version = (2, 5)
else:
self.version = str(self.config.version).split(".")
self.version = tuple([int(x) for x in self.version])
self.llm = self.init_llm(config, cache_config, quant_config)
self.vpm = self.init_vision_module()
param_dtype = torch.get_default_dtype()
self.vpm.to(dtype=param_dtype)
self.vision_dim = self.vpm.embed_dim if self.version == 2.0 \
self.vision_dim = self.vpm.embed_dim if self.version == (2, 0) \
else self.vpm.embeddings.embed_dim
self.embed_dim = self.llm.config.hidden_size
self.embed_dim = self.config.hidden_size
self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
self.resampler.to(device="cuda", dtype=param_dtype)
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def init_llm(self, config, cache_config, quant_config):
if self.version == 2.0:
return MiniCPMForCausalLM(config,
def init_llm(self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
if self.version == (2, 0):
return MiniCPMModel(config,
cache_config=cache_config,
quant_config=quant_config)
elif self.version == (2, 5):
return LlamaModel(config,
cache_config=cache_config,
quant_config=quant_config)
else:
return LlamaForCausalLM(config,
return Qwen2Model(config,
cache_config=cache_config,
quant_config=quant_config)
def init_vision_module(self):
if self.version == 2.0:
if self.version == (2, 0):
try:
import timm
except ImportError:
@ -433,16 +468,30 @@ class MiniCPMV(nn.Module, SupportsVision):
if self.config.drop_vision_last_layer:
model.blocks = model.blocks[:-1]
else:
elif self.version == (2, 5):
from transformers.models.idefics2.modeling_idefics2 import (
Idefics2VisionTransformer)
model = Idefics2VisionTransformer(self.config.vision_config)
if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1]
else:
from vllm.model_executor.models.na_vit import (
SiglipVisionTransformer)
if self.config._attn_implementation == 'flash_attention_2':
self.config.vision_config._attn_implementation \
= 'flash_attention_2'
else:
# sdpa is not supported
self.config.vision_config._attn_implementation = 'eager'
model = SiglipVisionTransformer(self.config.vision_config)
if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1]
return model
def init_resampler(self, embed_dim, vision_dim):
def init_resampler(self, embed_dim: int, vision_dim: int):
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.float16)
if self.version == 2.0:
if self.version == (2, 0):
resampler = Resampler(grid_size=int(
math.sqrt(self.config.query_num)),
num_queries=None,
@ -463,11 +512,11 @@ class MiniCPMV(nn.Module, SupportsVision):
return resampler
def get_vision_embedding(self,
pixel_values,
patch_attn_mask=None,
tgt_sizes=None,
version=2.0):
if version == 2.0:
pixel_values: List[List[torch.Tensor]],
patch_attn_mask: Optional[torch.Tensor] = None,
tgt_sizes: Optional[torch.Tensor] = None,
version: Tuple[int, int] = (2, 0)):
if version == (2, 0):
res = []
dtype = self.vpm.pos_embed.data.dtype
for pixel_value in pixel_values:
@ -484,21 +533,32 @@ class MiniCPMV(nn.Module, SupportsVision):
num_prefix_tokens:]
res.append(self.resampler(vision_embedding, tgt_size))
return torch.vstack(res)
else:
elif version == (2, 5):
vision_embedding = self.vpm(
pixel_values.type(dtype),
patch_attention_mask=patch_attn_mask).last_hidden_state
vision_embedding = self.resampler(vision_embedding, tgt_sizes)
else:
vision_embedding = self.vpm(pixel_values.type(dtype),
patch_attention_mask=patch_attn_mask,
tgt_sizes=tgt_sizes).last_hidden_state
def get_image_bounds(self, input_ids):
def get_image_bounds(self, input_ids: torch.Tensor):
tokenizer = cached_get_tokenizer(self.config._name_or_path,
trust_remote_code=True)
im_start_token_id = tokenizer.im_start_id
im_end_token_id = tokenizer.im_end_id
image_start_tokens = torch.where(input_ids == im_start_token_id)[0]
if not hasattr(tokenizer, "slice_start_id"):
start_cond = input_ids == tokenizer.im_start_id
end_cond = input_ids == tokenizer.im_end_id
else:
start_cond = (input_ids == tokenizer.im_start_id) | (
input_ids == tokenizer.slice_start_id)
end_cond = (input_ids == tokenizer.im_end_id) | (
input_ids == tokenizer.slice_end_id)
image_start_tokens = torch.where(start_cond)[0]
image_start_tokens += 1
image_end_tokens = torch.where(input_ids == im_end_token_id)[0]
valid_image_nums = min(len(image_start_tokens), len(image_end_tokens))
image_end_tokens = torch.where(end_cond)[0]
valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
if valid_image_nums == 0:
return []
image_bound = torch.hstack([
@ -508,12 +568,14 @@ class MiniCPMV(nn.Module, SupportsVision):
return image_bound
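To see how the updated bound computation pairs start and end markers, here is a toy, self-contained illustration with made-up token ids (not from the commit; the column stacking fills in the part elided by the hunk boundary above):

import torch

# Made-up token ids standing in for tokenizer.im_start_id / im_end_id.
IM_START, IM_END = 101, 102
input_ids = torch.tensor([1, 101, 5, 6, 102, 2, 101, 7, 102])

start_cond = input_ids == IM_START
end_cond = input_ids == IM_END

image_start_tokens = torch.where(start_cond)[0] + 1  # first token inside each image span
image_end_tokens = torch.where(end_cond)[0]
valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))

image_bound = torch.hstack([
    image_start_tokens[:valid_image_nums].unsqueeze(-1),
    image_end_tokens[:valid_image_nums].unsqueeze(-1),
])
print(image_bound)  # tensor([[2, 4], [7, 8]])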
def get_vision_hidden_states(self, data):
def get_vision_hidden_states(self, data: Dict[str,
Union[List[torch.Tensor],
torch.Tensor]]):
if "vision_hidden_states" not in data:
pixel_values = data["pixel_values"]
tgt_sizes = data["tgt_sizes"]
vision_hidden_states = []
if self.version == 2.0:
if self.version == (2, 0):
if pixel_values is not None and len(pixel_values) > 0:
vision_hidden_states = self.get_vision_embedding(
pixel_values)
@ -534,17 +596,26 @@ class MiniCPMV(nn.Module, SupportsVision):
B, L, _ = all_pixel_values.shape
all_pixel_values = all_pixel_values.permute(
0, 2, 1).reshape(B, 3, -1, L)
patch_attn_mask = torch.zeros((B, 1, max_patches),
dtype=torch.bool,
device=device)
if self.version == (2, 5):
for i in range(B):
patch_attn_mask[i, :tgt_sizes[i][0] *
tgt_sizes[i][1]] = True
vision_embedding = self.vpm(
all_pixel_values.type(dtype),
patch_attention_mask=patch_attn_mask).last_hidden_state
patch_attention_mask=patch_attn_mask
).last_hidden_state
else:
for i in range(B):
patch_attn_mask[i, 0, :tgt_sizes[i][0] *
tgt_sizes[i][1]] = True
vision_embedding = self.vpm(
all_pixel_values.type(dtype),
patch_attention_mask=patch_attn_mask,
tgt_sizes=tgt_sizes).last_hidden_state
vision_hidden_states = self.resampler(
vision_embedding, tgt_sizes)
@ -556,7 +627,8 @@ class MiniCPMV(nn.Module, SupportsVision):
return vision_hidden_states
def get_embedding(self, data):
def get_embedding(self, data: Dict[str, Union[List[torch.Tensor],
torch.Tensor]]):
input_ids = data["input_ids"]
vision_hidden_states = self.get_vision_hidden_states(data)
@ -565,11 +637,11 @@ class MiniCPMV(nn.Module, SupportsVision):
else:
image_bounds = []
if hasattr(self.llm.config, 'scale_emb'):
vlm_embedding = self.llm.model.embed_tokens(
input_ids) * self.llm.config.scale_emb
if hasattr(self.config, 'scale_emb'):
vlm_embedding = self.llm.embed_tokens(
input_ids) * self.config.scale_emb
else:
vlm_embedding = self.llm.model.embed_tokens(input_ids)
vlm_embedding = self.llm.embed_tokens(input_ids)
vision_hidden_states = [
i.type(vlm_embedding.dtype) if isinstance(i, torch.Tensor) else i
for i in vision_hidden_states
@ -587,7 +659,9 @@ class MiniCPMV(nn.Module, SupportsVision):
vision_hidden_states.view(-1, vision_hidden_states.shape[-1]))
return vlm_embedding, vision_hidden_states
def process_multimodal_inputs(self, inputs):
def process_multimodal_inputs(self, inputs: Dict[str,
Union[List[torch.Tensor],
torch.Tensor]]):
pixel_values = []
tgt_sizes = []
for b in range(len(inputs["pixel_values"])):
@ -613,7 +687,6 @@ class MiniCPMV(nn.Module, SupportsVision):
"input_ids": input_ids,
"tgt_sizes": kwargs.pop("tgt_sizes", None),
}
inputs = self.process_multimodal_inputs(inputs)
vlm_embeddings, vision_hidden_states = self.get_embedding(inputs)
@ -623,19 +696,21 @@ class MiniCPMV(nn.Module, SupportsVision):
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
input_embeds=vlm_embeddings)
inputs_embeds=vlm_embeddings)
return output
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
return self.llm.compute_logits(hidden_states, sampling_metadata)
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.llm.sample(logits, sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
@ -649,9 +724,9 @@ class MiniCPMV(nn.Module, SupportsVision):
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
# for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
# if key_to_modify in name:
# name = name.replace(key_to_modify, new_key)
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name


@ -0,0 +1,804 @@
import logging
import math
import os
import warnings
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.init import _calculate_fan_in_and_fan_out
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from transformers.modeling_outputs import (BaseModelOutput,
BaseModelOutputWithPooling)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (ModelOutput, is_flash_attn_2_available,
replace_return_docstrings)
logger = logging.getLogger("vllm")
# For Siglip: copied from
# HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit, with tgt_sizes added.
# Hints are removed since there is little chance this code will change.
class SiglipVisionConfig(PretrainedConfig):
model_type = "siglip_vision_model"
def __init__(
self,
hidden_size=768,
intermediate_size=3072,
num_hidden_layers=12,
num_attention_heads=12,
num_channels=3,
image_size=224,
patch_size=16,
hidden_act="gelu_pytorch_tanh",
layer_norm_eps=1e-6,
attention_dropout=0.0,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_channels = num_channels
self.patch_size = patch_size
self.image_size = image_size
self.attention_dropout = attention_dropout
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str,
os.PathLike],
**kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(
pretrained_model_name_or_path, **kwargs)
# get the vision config dict if we are loading from SiglipConfig
if config_dict.get("model_type") == "siglip":
config_dict = config_dict["vision_config"]
if "model_type" in config_dict and hasattr(
cls,
"model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
"You are using a model of type %s to "
"instantiate a model of type %s. "
"This is not supported for all configurations"
"of models and can yield errors.", config_dict['model_type'],
cls.model_type)
return cls.from_dict(config_dict, **kwargs)
_CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
"google/siglip-base-patch16-224",
# See all SigLIP models at https://huggingface.co/models?filter=siglip
]
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import pad_input # noqa
from flash_attn.bert_padding import index_first_axis, unpad_input
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
def _trunc_normal_(tensor, mean, std, a, b):
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2,
)
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l_ = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l_ - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
if tensor.dtype in [torch.float16, torch.bfloat16]:
# The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu
og_dtype = tensor.dtype
tensor = tensor.to(torch.float32)
tensor.erfinv_()
tensor = tensor.to(og_dtype)
else:
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
if tensor.dtype == torch.float16:
# The `clamp_` op is not (yet?) defined in float16+cpu
tensor = tensor.to(torch.float32)
tensor.clamp_(min=a, max=b)
tensor = tensor.to(torch.float16)
else:
tensor.clamp_(min=a, max=b)
def trunc_normal_tf_(tensor: torch.Tensor,
mean: float = 0.0,
std: float = 1.0,
a: float = -2.0,
b: float = 2.0) -> torch.Tensor:
with torch.no_grad():
_trunc_normal_(tensor, 0, 1.0, a, b)
tensor.mul_(std).add_(mean)
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
if mode == "fan_in":
denom = fan_in
elif mode == "fan_out":
denom = fan_out
elif mode == "fan_avg":
denom = (fan_in + fan_out) / 2
variance = scale / denom
if distribution == "truncated_normal":
# constant is stddev of standard normal truncated to (-2, 2)
trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
elif distribution == "normal":
with torch.no_grad():
tensor.normal_(std=math.sqrt(variance))
elif distribution == "uniform":
bound = math.sqrt(3 * variance)
with torch.no_grad():
tensor.uniform_(-bound, bound)
else:
raise ValueError(f"invalid distribution {distribution}")
def lecun_normal_(tensor):
variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
def default_flax_embed_init(tensor):
variance_scaling_(tensor, mode="fan_in", distribution="normal")
class SiglipVisionModelOutput(ModelOutput):
image_embeds: Optional[torch.FloatTensor] = None
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
class SiglipVisionEmbeddings(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
padding="valid",
)
self.num_patches_per_side = self.image_size // self.patch_size
self.num_patches = self.num_patches_per_side**2
self.num_positions = self.num_patches
self.position_embedding = nn.Embedding(self.num_positions,
self.embed_dim)
def forward(self,
pixel_values: torch.FloatTensor,
patch_attention_mask: torch.BoolTensor,
tgt_sizes: Optional[torch.IntTensor] = None) -> torch.Tensor:
batch_size = pixel_values.size(0)
patch_embeds = self.patch_embedding(pixel_values)
embeddings = patch_embeds.flatten(2).transpose(1, 2)
max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
max_nb_patches_h, max_nb_patches_w = (max_im_h // self.patch_size,
max_im_w // self.patch_size)
boundaries = torch.arange(1 / self.num_patches_per_side, 1.0,
1 / self.num_patches_per_side)
position_ids = torch.full(
size=(
batch_size,
max_nb_patches_h * max_nb_patches_w,
),
fill_value=0,
)
for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
if tgt_sizes is not None:
nb_patches_h = tgt_sizes[batch_idx][0]
nb_patches_w = tgt_sizes[batch_idx][1]
else:
nb_patches_h = p_attn_mask[:, 0].sum()
nb_patches_w = p_attn_mask[0].sum()
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
bucket_coords_h = torch.bucketize(fractional_coords_h,
boundaries,
right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w,
boundaries,
right=True)
pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side +
bucket_coords_w).flatten()
position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
position_ids = position_ids.to(self.position_embedding.weight.device)
embeddings = embeddings + self.position_embedding(position_ids)
return embeddings
class SiglipAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
# Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
def __init__(self, config):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
"embed_dim must be divisible by num_heads (got `embed_dim`: "
f"{self.embed_dim} and `num_heads`:"
f" {self.num_heads}).")
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
batch_size, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(batch_size, q_len, self.num_heads,
self.head_dim).transpose(1, 2)
key_states = key_states.view(batch_size, q_len, self.num_heads,
self.head_dim).transpose(1, 2)
value_states = value_states.view(batch_size, q_len, self.num_heads,
self.head_dim).transpose(1, 2)
k_v_seq_len = key_states.shape[-2]
attn_weights = torch.matmul(query_states, key_states.transpose(
2, 3)) * self.scale
if attn_weights.size() != (batch_size, self.num_heads, q_len,
k_v_seq_len):
raise ValueError(
"Attention weights should be of size "
f"{(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
f" {attn_weights.size()}")
if attention_mask is not None:
if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
raise ValueError(
"Attention mask should be of size "
f"{(batch_size, 1, q_len, k_v_seq_len)}",
f"but is {attention_mask.size()}")
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights,
dim=-1,
dtype=torch.float32).to(
query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights,
p=self.dropout,
training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (batch_size, self.num_heads, q_len,
self.head_dim):
raise ValueError(
"`attn_output` should be of size "
f"{(batch_size, self.num_heads, q_len, self.head_dim)}, "
"but is"
f" {attn_output.size()}")
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights
class SiglipFlashAttention2(SiglipAttention):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.is_causal = False # Hack to make sure we don't use a causal mask
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
output_attentions = False
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads,
self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_heads,
self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_heads,
self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(
kv_seq_len, self.layer_idx)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.dropout if self.training else 0.0
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
logger.warning(
"The input hidden states seems to be "
"silently casted in float32, "
"this might be related to the fact "
"you have upcasted embedding or layer norm layers in float32. "
"We will cast back the input in"
" %s.", target_dtype)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = self._flash_attention_forward(query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=dropout_rate)
attn_output = attn_output.reshape(bsz, q_len,
self.embed_dim).contiguous()
attn_output = self.out_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights
def _flash_attention_forward(self,
query_states,
key_states,
value_states,
attention_mask,
query_length,
dropout=0.0,
softmax_scale=None):
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
(query_states, key_states, value_states, indices_q, cu_seq_lens,
max_seq_lens) = self._upad_input(query_states, key_states,
value_states, attention_mask,
query_length)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size,
query_length)
else:
attn_output = flash_attn_func(query_states,
key_states,
value_states,
dropout,
softmax_scale=softmax_scale,
causal=causal)
return attn_output
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask,
query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
head_dim), indices_k)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
head_dim), indices_k)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads,
head_dim), indices_k)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
(query_layer, indices_q, cu_seqlens_q,
max_seqlen_in_batch_q) = unpad_input(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
class SiglipMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer
# with CLIP->Siglip
class SiglipEncoderLayer(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.embed_dim = config.hidden_size
self._use_flash_attention_2 = (
config._attn_implementation == "flash_attention_2")
self.self_attn = (SiglipAttention(config)
if not self._use_flash_attention_2 else
SiglipFlashAttention2(config))
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps)
self.mlp = SiglipMLP(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.FloatTensor]:
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states, )
if output_attentions:
outputs += (attn_weights, )
return outputs
class SiglipPreTrainedModel(PreTrainedModel):
config_class = SiglipVisionConfig
base_model_prefix = "siglip"
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, SiglipVisionEmbeddings):
width = self.config.hidden_size
nn.init.normal_(module.position_embedding.weight,
std=1 / np.sqrt(width))
elif isinstance(module, nn.Embedding):
default_flax_embed_init(module.weight)
elif isinstance(module, SiglipAttention):
nn.init.normal_(module.q_proj.weight)
nn.init.normal_(module.k_proj.weight)
nn.init.normal_(module.v_proj.weight)
nn.init.normal_(module.out_proj.weight)
nn.init.zeros_(module.q_proj.bias)
nn.init.zeros_(module.k_proj.bias)
nn.init.zeros_(module.v_proj.bias)
nn.init.zeros_(module.out_proj.bias)
elif isinstance(module, SiglipMLP):
nn.init.normal_(module.fc1.weight)
nn.init.normal_(module.fc2.weight)
nn.init.normal_(module.fc1.bias, std=1e-6)
nn.init.normal_(module.fc2.bias, std=1e-6)
elif isinstance(module, (nn.Linear, nn.Conv2d)):
lecun_normal_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
# Copied from transformers.models.clip.modeling_clip.CLIPEncoder
# with CLIP->Siglip
class SiglipEncoder(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.config = config
self.layers = nn.ModuleList([
SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)
])
self.gradient_checkpointing = False
# Ignore copy
def forward(
self,
inputs_embeds,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None \
else self.config.output_attentions
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else
self.config.output_hidden_states)
return_dict = return_dict if return_dict is not None \
else self.config.use_return_dict
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
hidden_states = inputs_embeds
for encoder_layer in self.layers:
if output_hidden_states:
encoder_states = encoder_states + (hidden_states, )
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1], )
if output_hidden_states:
encoder_states = encoder_states + (hidden_states, )
if not return_dict:
return tuple(
v for v in [hidden_states, encoder_states, all_attentions]
if v is not None)
return BaseModelOutput(last_hidden_state=hidden_states,
hidden_states=encoder_states,
attentions=all_attentions)
class SiglipVisionTransformer(SiglipPreTrainedModel):
config_class = SiglipVisionConfig
main_input_name = "pixel_values"
_supports_flash_attn_2 = True
def __init__(self, config: SiglipVisionConfig):
super().__init__(config)
self.config = config
embed_dim = config.hidden_size
self.embeddings = SiglipVisionEmbeddings(config)
self.encoder = SiglipEncoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim,
eps=config.layer_norm_eps)
self._use_flash_attention_2 = (
config._attn_implementation == "flash_attention_2")
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> nn.Module:
return self.embeddings.patch_embedding
@replace_return_docstrings(output_type=BaseModelOutputWithPooling,
config_class=SiglipVisionConfig)
def forward(
self,
pixel_values,
patch_attention_mask: Optional[torch.BoolTensor] = None,
tgt_sizes: Optional[torch.IntTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
"""
output_attentions = output_attentions if output_attentions is not None \
else self.config.output_attentions
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else
self.config.output_hidden_states)
return_dict = return_dict if return_dict is not None \
else self.config.use_return_dict
batch_size = pixel_values.size(0)
if patch_attention_mask is None:
patch_attention_mask = torch.ones(
size=(
batch_size,
pixel_values.size(2) // self.config.patch_size,
pixel_values.size(3) // self.config.patch_size,
),
dtype=torch.bool,
device=pixel_values.device,
)
hidden_states = self.embeddings(
pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask,
tgt_sizes=tgt_sizes)
patch_attention_mask = patch_attention_mask.view(batch_size, -1)
# The call to `_upad_input` in `_flash_attention_forward` is expensive
# So when the `patch_attention_mask` is full of 1s
# (i.e. attending to the whole sequence),
# avoiding passing the attention_mask,
# which is equivalent to attending to the full sequence
if not torch.any(~patch_attention_mask):
attention_mask = None
else:
attention_mask = (_prepare_4d_attention_mask(
patch_attention_mask, hidden_states.dtype)
if not self._use_flash_attention_2 else
patch_attention_mask)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.post_layernorm(last_hidden_state)
if not return_dict:
return (last_hidden_state, None) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=None,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
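A rough smoke test for the vendored vision tower (not part of the commit; the tiny config values are arbitrary, and setting _attn_implementation this way assumes a transformers version that exposes the setter):

import torch

# Small config so the forward pass is cheap; eager attention avoids the flash-attn path.
cfg = SiglipVisionConfig(hidden_size=64, intermediate_size=128,
                         num_hidden_layers=2, num_attention_heads=4,
                         image_size=32, patch_size=16)
cfg._attn_implementation = "eager"
model = SiglipVisionTransformer(cfg).eval()

pixel_values = torch.randn(1, 3, 32, 32)  # one 32x32 RGB image -> 2x2 patches
with torch.no_grad():
    out = model(pixel_values)             # patch_attention_mask defaults to all-ones
print(out.last_hidden_state.shape)        # expected: torch.Size([1, 4, 64])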


@ -342,7 +342,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,