diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py index 4821f15..7e7e769 100644 --- a/nodes/model_optimization_nodes.py +++ b/nodes/model_optimization_nodes.py @@ -40,6 +40,7 @@ class BaseLoaderKJ: encoder_hidden_states_mask: torch.FloatTensor = None, attention_mask: Optional[torch.FloatTensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, + transformer_options={}, ) -> Tuple[torch.Tensor, torch.Tensor]: seq_txt = encoder_hidden_states.shape[1] @@ -67,7 +68,7 @@ class BaseLoaderKJ: joint_key = joint_key.flatten(start_dim=2) joint_value = joint_value.flatten(start_dim=2) - joint_hidden_states = attention_sage(joint_query, joint_key, joint_value, self.heads, attention_mask) + joint_hidden_states = attention_sage(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options) txt_attn_output = joint_hidden_states[:, :seq_txt, :] img_attn_output = joint_hidden_states[:, seq_txt:, :]