mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2025-12-13 23:04:40 +08:00
Fix qwen sage patch too
This commit is contained in:
parent
5dc5a15cc4
commit
e833a3f7df
@ -40,6 +40,7 @@ class BaseLoaderKJ:
|
|||||||
encoder_hidden_states_mask: torch.FloatTensor = None,
|
encoder_hidden_states_mask: torch.FloatTensor = None,
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||||
|
transformer_options={},
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
seq_txt = encoder_hidden_states.shape[1]
|
seq_txt = encoder_hidden_states.shape[1]
|
||||||
|
|
||||||
@ -67,7 +68,7 @@ class BaseLoaderKJ:
|
|||||||
joint_key = joint_key.flatten(start_dim=2)
|
joint_key = joint_key.flatten(start_dim=2)
|
||||||
joint_value = joint_value.flatten(start_dim=2)
|
joint_value = joint_value.flatten(start_dim=2)
|
||||||
|
|
||||||
joint_hidden_states = attention_sage(joint_query, joint_key, joint_value, self.heads, attention_mask)
|
joint_hidden_states = attention_sage(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options)
|
||||||
|
|
||||||
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
|
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
|
||||||
img_attn_output = joint_hidden_states[:, seq_txt:, :]
|
img_attn_output = joint_hidden_states[:, seq_txt:, :]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user