Fix qwen sage patch too

Author: kijai
Date: 2025-09-13 12:59:35 +03:00
Parent: 5dc5a15cc4
Commit: e833a3f7df

@@ -40,6 +40,7 @@ class BaseLoaderKJ:
             encoder_hidden_states_mask: torch.FloatTensor = None,
             attention_mask: Optional[torch.FloatTensor] = None,
             image_rotary_emb: Optional[torch.Tensor] = None,
+            transformer_options={},
         ) -> Tuple[torch.Tensor, torch.Tensor]:
             seq_txt = encoder_hidden_states.shape[1]
@@ -67,7 +68,7 @@ class BaseLoaderKJ:
             joint_key = joint_key.flatten(start_dim=2)
             joint_value = joint_value.flatten(start_dim=2)
-            joint_hidden_states = attention_sage(joint_query, joint_key, joint_value, self.heads, attention_mask)
+            joint_hidden_states = attention_sage(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options)
             txt_attn_output = joint_hidden_states[:, :seq_txt, :]
             img_attn_output = joint_hidden_states[:, seq_txt:, :]
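
For context: ComfyUI threads a transformer_options dict through model forward calls, so any patched attention has to both accept the keyword and pass it along; otherwise the options are silently dropped, or the call fails once the caller starts supplying the argument. This is what the commit fixes for the Qwen sage patch. Below is a minimal, runnable sketch of that pass-through pattern; the simplified signatures and the plain SDPA stand-in for the sage-attention kernel are illustrative assumptions, not the actual KJNodes code.

    # Sketch of the pass-through pattern this commit applies
    # (hypothetical simplified signatures, not the real KJNodes code).
    from typing import Optional
    import torch
    import torch.nn.functional as F

    def attention_sage(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                       heads: int, mask: Optional[torch.Tensor] = None,
                       transformer_options: dict = {}) -> torch.Tensor:
        # The wrapper can consult per-run settings here (e.g. active patches);
        # plain SDPA stands in for the real sage-attention kernel.
        b, n, dim = q.shape
        q, k, v = (t.view(b, n, heads, dim // heads).transpose(1, 2)
                   for t in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        return out.transpose(1, 2).reshape(b, n, dim)

    def patched_attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                                  heads: int,
                                  attention_mask: Optional[torch.Tensor] = None,
                                  transformer_options: dict = {}) -> torch.Tensor:
        # Accept the dict the caller now passes and forward it explicitly;
        # omitting the keyword in this inner call is the bug the commit fixes.
        return attention_sage(q, k, v, heads, attention_mask,
                              transformer_options=transformer_options)

    # Usage: tensors are (batch, sequence, heads * head_dim).
    x = torch.randn(1, 8, 64)
    out = patched_attention_forward(x, x, x, heads=4,
                                    transformer_options={"patches": {}})
    print(out.shape)  # torch.Size([1, 8, 64])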