Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-09 06:14:57 +08:00)
Fix DotsOCR tensor type (#26281)
Signed-off-by: what_in_the_nim <chatcharinsang@gmail.com>
parent ab5e7d93f4
commit fc679696f8
@@ -617,7 +617,7 @@ class DotsVisionTransformer(nn.Module):
     def device(self) -> torch.device:
         return self.patch_embed.patchifier.proj.weight.device
 
-    def get_pos_ids_by_grid(self, grid_thw):
+    def get_pos_ids_by_grid(self, grid_thw: list[list[int]]) -> list[torch.Tensor]:
         pos_ids = []
         for t, h, w in grid_thw:
             hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
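
The only change in this hunk is the annotation: grid_thw is documented as a plain Python list of (t, h, w) triples rather than a tensor, matching how the caller now invokes the vision tower. As a rough illustration (hypothetical grid values, and ignoring the spatial-merge reordering the real method applies), the per-grid position ids built this way look like:

    import torch

    grid_thw = [[1, 2, 3]]  # hypothetical: 1 frame, 2 patches high, 3 patches wide

    pos_ids = []
    for t, h, w in grid_thw:
        hpos = torch.arange(h).unsqueeze(1).expand(-1, w)  # row index of each patch
        wpos = torch.arange(w).unsqueeze(0).expand(h, -1)  # column index of each patch
        pos_ids.append(torch.stack([hpos, wpos], dim=-1).reshape(-1, 2).repeat(t, 1))

    print(pos_ids[0])
    # tensor([[0, 0], [0, 1], [0, 2],
    #         [1, 0], [1, 1], [1, 2]])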
@@ -643,10 +643,10 @@ class DotsVisionTransformer(nn.Module):
 
         return pos_ids
 
-    def rot_pos_emb(self, grid_thw):
+    def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor:
         pos_ids = self.get_pos_ids_by_grid(grid_thw)
         pos_ids = torch.cat(pos_ids, dim=0)
-        max_grid_size = grid_thw[:, 1:].max()
+        max_grid_size = max(max(h, w) for _, h, w in grid_thw)
         rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb
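
This hunk carries the actual type fix: grid_thw[:, 1:].max() is tensor slicing and fails on a nested Python list, whereas the generator expression yields the same maximum spatial extent from plain ints. A quick equivalence check with hypothetical values:

    import torch

    grid_thw_list = [[1, 34, 52], [1, 16, 28]]      # hypothetical (t, h, w) per image
    grid_thw_tensor = torch.tensor(grid_thw_list)

    old_style = grid_thw_tensor[:, 1:].max()                 # needs a tensor -> tensor(52)
    new_style = max(max(h, w) for _, h, w in grid_thw_list)  # works on the list -> 52

    assert int(old_style) == new_style == 52
    # grid_thw_list[:, 1:] would raise:
    # TypeError: list indices must be integers or slices, not tuple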
@@ -667,13 +667,13 @@ class DotsVisionTransformer(nn.Module):
     def forward(
         self, hidden_states: torch.Tensor, grid_thw: list[list[int]]
     ) -> torch.Tensor:
+        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+
+        # Convert grid_thw to tensor (always expecting list format now)
+        grid_thw = torch.tensor(grid_thw, device=hidden_states.device, dtype=torch.long)
         hidden_states = hidden_states.to(self.dtype)
         hidden_states = self.patch_embed(hidden_states, grid_thw)
 
-        rotary_pos_emb = self.rot_pos_emb(grid_thw)
-
         cu_seqlens = torch.repeat_interleave(
             grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
         ).cumsum(
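
The rotary embedding is now computed from the list form up front, and only then is grid_thw materialized as a long tensor on the right device for patch_embed and the cu_seqlens arithmetic below. For reference, a standalone sketch of the repeat_interleave/cumsum step (hypothetical grids, covering only what is visible in the hunk):

    import torch

    grid_thw = torch.tensor([[1, 4, 6], [2, 2, 2]], dtype=torch.long)  # hypothetical (t, h, w)

    # h * w tokens per temporal frame, repeated t times per image, then a running
    # total giving the attention boundaries between frames.
    seq_lens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
    cu_seqlens = seq_lens.cumsum(dim=0)

    print(seq_lens)    # tensor([24,  4,  4])
    print(cu_seqlens)  # tensor([24, 28, 32])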
@@ -807,7 +807,7 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
                     rope_type="rope_3d",
                 )
             else:
-                image_embeds = self.vision_tower(pixel_values, grid_thw)[
+                image_embeds = self.vision_tower(pixel_values, grid_thw_list)[
                     :, : self.config.hidden_size
                 ]
 
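With this change the caller hands the vision tower the plain Python list (grid_thw_list) instead of the tensor, matching the list[list[int]] annotation introduced above; the tower now does its own tensor conversion internally where tensor math is needed. A minimal sketch of how such a list is typically derived from the processor output (names hypothetical):

    import torch

    # Hypothetical: image_grid_thw as produced by the multimodal processor, one row per image.
    image_grid_thw = torch.tensor([[1, 34, 52], [1, 16, 28]])
    grid_thw_list: list[list[int]] = image_grid_thw.tolist()
    # -> [[1, 34, 52], [1, 16, 28]]; this list form is what the vision tower receives.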