# 第2题-大模型Attention模块开发 - 题目详情 - CodeFun2000
# (Problem 2: LLM Attention module development - problem details - CodeFun2000)
# Language: python
import sys
import numpy as np
def _attention_output_sum(n, m, h):
    """Return the rounded sum of every entry of the simplified attention output.

    The model built here is fixed by the three sizes:

    * ``X`` is an ``n x m`` matrix of ones.
    * The query, key and value weights are the same ``m x h`` matrix with
      ones on and above the main diagonal and zeros below it.
    * Scores are ``Q @ K.T / sqrt(h)``.
    * The "softmax" is the simplified form used by this solution: every
      score is divided by the plain sum of its row (no exponentiation).

    Args:
        n: Number of rows of ``X``.
        m: Number of columns of ``X`` / rows of the weight matrix.
        h: Number of columns of the weight matrix.

    Returns:
        The sum of all entries of ``weights @ V`` rounded to the nearest
        integer. ``0`` when any of the sizes is zero.

    Raises:
        ValueError: If any size is negative.
    """
    if n < 0 or m < 0 or h < 0:
        raise ValueError("n, m and h must be non-negative")
    if n == 0 or m == 0 or h == 0:
        # With an empty dimension the row sums (or sqrt(h)) are zero and the
        # normalisation would be 0/0; the output has nothing to add up.
        return 0

    X = np.ones((n, m), dtype=np.float64)
    # Ones where column >= row, also for non-square shapes.
    W = np.triu(np.ones((m, h), dtype=np.float64))

    # The three weight matrices are identical, so one projection serves
    # as query, key and value.
    Q = K = V = X @ W

    scores = (Q @ K.T) / np.sqrt(h)
    # Simplified softmax: normalise each row by its own sum.
    weights = scores / scores.sum(axis=1, keepdims=True)

    return int(np.round(np.sum(weights @ V)))


def solve():
    """Read ``n``, ``m`` and ``h`` from standard input and print the result.

    The first three whitespace-separated tokens of standard input are
    parsed as integers; any further tokens are ignored. If fewer than three
    tokens are available nothing is printed.

    Raises:
        ValueError: If a token is not an integer or a size is negative.
    """
    data = sys.stdin.read().split()
    if len(data) < 3:
        # Empty or incomplete input: there is nothing to compute.
        return
    n, m, h = (int(token) for token in data[:3])
    print(_attention_output_sum(n, m, h))
if __name__ == "__main__":
    # Run the solver only when executed as a script, not when imported.
    solve()