mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-27 05:58:01 +08:00
[Bugfix] Fix HunyuanOCR cross-image contamination in batch processing (#30344)
Signed-off-by: Lennart Brog <lennart.borg@list-ag.de>
Signed-off-by: Anker <20343812+anker-c2@users.noreply.github.com>
This commit is contained in:
parent
253305d5b2
commit
e8e8cd73e5
@@ -502,6 +502,7 @@ class HunYuanVisionTransformer(nn.Module):
|
||||
cu_seqlens: list = [0]
|
||||
|
||||
hidden_states = x.to(device=self.device, dtype=self.dtype)
|
||||
# embeddings = patch_embeds + patch_pos_embed
|
||||
hidden_states = self.embeddings(hidden_states, grid_thw)
|
||||
|
||||
for t, h, w in grid_thw:
|
||||
@@ -515,8 +516,14 @@ class HunYuanVisionTransformer(nn.Module):
|
||||
|
||||
hidden_states = hidden_states.reshape(seq_len, -1)
|
||||
hidden_states = hidden_states.unsqueeze(0)
|
||||
for layer_num, layer in enumerate(self.layers):
|
||||
hidden_states = layer(hidden_states)
|
||||
|
||||
# build per-image lengths once
|
||||
split_lengths = [int(h) * int(w) for (_, h, w) in grid_thw]
|
||||
for layer in self.layers:
|
||||
# hidden_states: (1, T_total, D)
|
||||
parts = hidden_states.split(split_lengths, dim=1) # list of (1, L_i, D)
|
||||
parts = [layer(p) for p in parts]
|
||||
hidden_states = torch.cat(parts, dim=1)
|
||||
|
||||
# adapter
|
||||
split_lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user