[Bugfix] Fix HunyuanOCR cross-image contamination in batch processing (#30344)

Signed-off-by: Lennart Brog <lennart.borg@list-ag.de>
Signed-off-by: Anker <20343812+anker-c2@users.noreply.github.com>
This commit is contained in:
Anker 2025-12-10 19:09:31 +01:00 committed by GitHub
parent 253305d5b2
commit e8e8cd73e5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -502,6 +502,7 @@ class HunYuanVisionTransformer(nn.Module):
cu_seqlens: list = [0]
hidden_states = x.to(device=self.device, dtype=self.dtype)
# embeddings = patch_embeds + patch_pos_embed
hidden_states = self.embeddings(hidden_states, grid_thw)
for t, h, w in grid_thw:
@ -515,8 +516,14 @@ class HunYuanVisionTransformer(nn.Module):
hidden_states = hidden_states.reshape(seq_len, -1)
hidden_states = hidden_states.unsqueeze(0)
for layer_num, layer in enumerate(self.layers):
hidden_states = layer(hidden_states)
# build per-image lengths once
split_lengths = [int(h) * int(w) for (_, h, w) in grid_thw]
for layer in self.layers:
# hidden_states: (1, T_total, D)
parts = hidden_states.split(split_lengths, dim=1) # list of (1, L_i, D)
parts = [layer(p) for p in parts]
hidden_states = torch.cat(parts, dim=1)
# adapter
split_lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()