Add bf16 model
parent 754085eaf5
commit 1cd5409295
@@ -127,7 +127,6 @@ class AsymmetricAttention(nn.Module):
     ):
         # Pre-norm for visual features
         x = modulated_rmsnorm(x, scale_x)  # (B, M, dim_x) where M = N / cp_group_size
-        #print("x in attn", x.dtype, x.device)
 
         # Process visual features
         qkv_x = self.qkv_x(x)  # (B, M, 3 * dim_x)
@@ -139,13 +138,6 @@ class AsymmetricAttention(nn.Module):
         # Process text features
         y = modulated_rmsnorm(y, scale_y)  # (B, L, dim_y)
         q_y, k_y, v_y = self.run_qkv_y(y)  # (B, L, local_heads, head_dim)
-        #print("y in attn", y.dtype, y.device)
-        #print(q_y.dtype, q_y.device)
-        #print(self.q_norm_y.weight.dtype, self.q_norm_y.weight.device)
-        # self.q_norm_y.weight = self.q_norm_y.weight.to(q_y.dtype)
-        # self.q_norm_y.bias = self.q_norm_y.bias.to(q_y.dtype)
-        # self.k_norm_y.weight = self.k_norm_y.weight.to(k_y.dtype)
-        # self.k_norm_y.bias = self.k_norm_y.bias.to(k_y.dtype)
         q_y = self.q_norm_y(q_y)
         k_y = self.k_norm_y(k_y)
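For context, the commented-out lines deleted in the second hunk were per-forward casts of the text-stream QK-norm parameters to the activation dtype. With a bf16 checkpoint, that cast can instead happen once at load time. Below is a minimal sketch of that idea; the helper name is hypothetical and not part of the repo, and it assumes the norm layers are plain nn.Module submodules with the attribute names shown.

    import torch
    import torch.nn as nn

    def cast_qk_norms_to_bf16(attn: nn.Module) -> None:
        # Hypothetical helper: convert the attention block's QK-norm parameters
        # to bfloat16 once, so no per-forward .to(q_y.dtype) casts are needed.
        for name in ("q_norm_x", "k_norm_x", "q_norm_y", "k_norm_y"):  # assumed attribute names
            norm = getattr(attn, name, None)
            if norm is not None:
                norm.to(torch.bfloat16)  # in-place dtype conversion of the submodule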