Fix DotsOCR tensor type (#26281)

Signed-off-by: what_in_the_nim <chatcharinsang@gmail.com>
This commit is contained in:
Chatcharin Sangbutsarakum 2025-10-06 19:23:43 +07:00 committed by GitHub
parent ab5e7d93f4
commit fc679696f8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -617,7 +617,7 @@ class DotsVisionTransformer(nn.Module):
def device(self) -> torch.device:
return self.patch_embed.patchifier.proj.weight.device
def get_pos_ids_by_grid(self, grid_thw):
def get_pos_ids_by_grid(self, grid_thw: list[list[int]]) -> list[torch.Tensor]:
pos_ids = []
for t, h, w in grid_thw:
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
@ -643,10 +643,10 @@ class DotsVisionTransformer(nn.Module):
return pos_ids
def rot_pos_emb(self, grid_thw):
def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor:
pos_ids = self.get_pos_ids_by_grid(grid_thw)
pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max()
max_grid_size = max(max(h, w) for _, h, w in grid_thw)
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
@ -667,13 +667,13 @@ class DotsVisionTransformer(nn.Module):
def forward(
self, hidden_states: torch.Tensor, grid_thw: list[list[int]]
) -> torch.Tensor:
rotary_pos_emb = self.rot_pos_emb(grid_thw)
# Convert grid_thw to tensor (always expecting list format now)
grid_thw = torch.tensor(grid_thw, device=hidden_states.device, dtype=torch.long)
hidden_states = hidden_states.to(self.dtype)
hidden_states = self.patch_embed(hidden_states, grid_thw)
rotary_pos_emb = self.rot_pos_emb(grid_thw)
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(
@ -807,7 +807,7 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
rope_type="rope_3d",
)
else:
image_embeds = self.vision_tower(pixel_values, grid_thw)[
image_embeds = self.vision_tower(pixel_values, grid_thw_list)[
:, : self.config.hidden_size
]