Fix DotsOCR tensor type (#26281)

Signed-off-by: what_in_the_nim <chatcharinsang@gmail.com>
This commit is contained in:
Chatcharin Sangbutsarakum 2025-10-06 19:23:43 +07:00 committed by GitHub
parent ab5e7d93f4
commit fc679696f8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -617,7 +617,7 @@ class DotsVisionTransformer(nn.Module):
def device(self) -> torch.device: def device(self) -> torch.device:
return self.patch_embed.patchifier.proj.weight.device return self.patch_embed.patchifier.proj.weight.device
def get_pos_ids_by_grid(self, grid_thw): def get_pos_ids_by_grid(self, grid_thw: list[list[int]]) -> list[torch.Tensor]:
pos_ids = [] pos_ids = []
for t, h, w in grid_thw: for t, h, w in grid_thw:
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
@ -643,10 +643,10 @@ class DotsVisionTransformer(nn.Module):
return pos_ids return pos_ids
def rot_pos_emb(self, grid_thw): def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor:
pos_ids = self.get_pos_ids_by_grid(grid_thw) pos_ids = self.get_pos_ids_by_grid(grid_thw)
pos_ids = torch.cat(pos_ids, dim=0) pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max() max_grid_size = max(max(h, w) for _, h, w in grid_thw)
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb return rotary_pos_emb
@ -667,13 +667,13 @@ class DotsVisionTransformer(nn.Module):
def forward( def forward(
self, hidden_states: torch.Tensor, grid_thw: list[list[int]] self, hidden_states: torch.Tensor, grid_thw: list[list[int]]
) -> torch.Tensor: ) -> torch.Tensor:
rotary_pos_emb = self.rot_pos_emb(grid_thw)
# Convert grid_thw to tensor (always expecting list format now) # Convert grid_thw to tensor (always expecting list format now)
grid_thw = torch.tensor(grid_thw, device=hidden_states.device, dtype=torch.long) grid_thw = torch.tensor(grid_thw, device=hidden_states.device, dtype=torch.long)
hidden_states = hidden_states.to(self.dtype) hidden_states = hidden_states.to(self.dtype)
hidden_states = self.patch_embed(hidden_states, grid_thw) hidden_states = self.patch_embed(hidden_states, grid_thw)
rotary_pos_emb = self.rot_pos_emb(grid_thw)
cu_seqlens = torch.repeat_interleave( cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum( ).cumsum(
@ -807,7 +807,7 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
rope_type="rope_3d", rope_type="rope_3d",
) )
else: else:
image_embeds = self.vision_tower(pixel_values, grid_thw)[ image_embeds = self.vision_tower(pixel_values, grid_thw_list)[
:, : self.config.hidden_size :, : self.config.hidden_size
] ]