mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 10:09:07 +08:00
Fix DotsOCR tensor type (#26281)
Signed-off-by: what_in_the_nim <chatcharinsang@gmail.com>
This commit is contained in:
parent
ab5e7d93f4
commit
fc679696f8
@ -617,7 +617,7 @@ class DotsVisionTransformer(nn.Module):
|
|||||||
def device(self) -> torch.device:
|
def device(self) -> torch.device:
|
||||||
return self.patch_embed.patchifier.proj.weight.device
|
return self.patch_embed.patchifier.proj.weight.device
|
||||||
|
|
||||||
def get_pos_ids_by_grid(self, grid_thw):
|
def get_pos_ids_by_grid(self, grid_thw: list[list[int]]) -> list[torch.Tensor]:
|
||||||
pos_ids = []
|
pos_ids = []
|
||||||
for t, h, w in grid_thw:
|
for t, h, w in grid_thw:
|
||||||
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
||||||
@ -643,10 +643,10 @@ class DotsVisionTransformer(nn.Module):
|
|||||||
|
|
||||||
return pos_ids
|
return pos_ids
|
||||||
|
|
||||||
def rot_pos_emb(self, grid_thw):
|
def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor:
|
||||||
pos_ids = self.get_pos_ids_by_grid(grid_thw)
|
pos_ids = self.get_pos_ids_by_grid(grid_thw)
|
||||||
pos_ids = torch.cat(pos_ids, dim=0)
|
pos_ids = torch.cat(pos_ids, dim=0)
|
||||||
max_grid_size = grid_thw[:, 1:].max()
|
max_grid_size = max(max(h, w) for _, h, w in grid_thw)
|
||||||
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
|
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
|
||||||
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
||||||
return rotary_pos_emb
|
return rotary_pos_emb
|
||||||
@ -667,13 +667,13 @@ class DotsVisionTransformer(nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self, hidden_states: torch.Tensor, grid_thw: list[list[int]]
|
self, hidden_states: torch.Tensor, grid_thw: list[list[int]]
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
||||||
|
|
||||||
# Convert grid_thw to tensor (always expecting list format now)
|
# Convert grid_thw to tensor (always expecting list format now)
|
||||||
grid_thw = torch.tensor(grid_thw, device=hidden_states.device, dtype=torch.long)
|
grid_thw = torch.tensor(grid_thw, device=hidden_states.device, dtype=torch.long)
|
||||||
hidden_states = hidden_states.to(self.dtype)
|
hidden_states = hidden_states.to(self.dtype)
|
||||||
hidden_states = self.patch_embed(hidden_states, grid_thw)
|
hidden_states = self.patch_embed(hidden_states, grid_thw)
|
||||||
|
|
||||||
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
|
||||||
|
|
||||||
cu_seqlens = torch.repeat_interleave(
|
cu_seqlens = torch.repeat_interleave(
|
||||||
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
||||||
).cumsum(
|
).cumsum(
|
||||||
@ -807,7 +807,7 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
|
|||||||
rope_type="rope_3d",
|
rope_type="rope_3d",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
image_embeds = self.vision_tower(pixel_values, grid_thw)[
|
image_embeds = self.vision_tower(pixel_values, grid_thw_list)[
|
||||||
:, : self.config.hidden_size
|
:, : self.config.hidden_size
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user