From e8e8cd73e5ddc4b56896e806066c37e9803e54b7 Mon Sep 17 00:00:00 2001 From: Anker <20343812+anker-c2@users.noreply.github.com> Date: Wed, 10 Dec 2025 19:09:31 +0100 Subject: [PATCH] [Bugfix] Fix HunyuanOCR cross-image contamination in batch processing (#30344) Signed-off-by: Lennart Brog Signed-off-by: Anker <20343812+anker-c2@users.noreply.github.com> --- vllm/model_executor/models/hunyuan_vision.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/hunyuan_vision.py b/vllm/model_executor/models/hunyuan_vision.py index e5c1be626be07..be084f4ee0f8e 100644 --- a/vllm/model_executor/models/hunyuan_vision.py +++ b/vllm/model_executor/models/hunyuan_vision.py @@ -502,6 +502,7 @@ class HunYuanVisionTransformer(nn.Module): cu_seqlens: list = [0] hidden_states = x.to(device=self.device, dtype=self.dtype) + # embeddings = patch_embeds + patch_pos_embed hidden_states = self.embeddings(hidden_states, grid_thw) for t, h, w in grid_thw: @@ -515,8 +516,14 @@ class HunYuanVisionTransformer(nn.Module): hidden_states = hidden_states.reshape(seq_len, -1) hidden_states = hidden_states.unsqueeze(0) - for layer_num, layer in enumerate(self.layers): - hidden_states = layer(hidden_states) + + # build per-image lengths once + split_lengths = [int(h) * int(w) for (_, h, w) in grid_thw] + for layer in self.layers: + # hidden_states: (1, T_total, D) + parts = hidden_states.split(split_lengths, dim=1) # list of (1, L_i, D) + parts = [layer(p) for p in parts] + hidden_states = torch.cat(parts, dim=1) # adapter split_lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()