第2题-大模型Attention模块开发

第2题-大模型Attention模块开发 - 题目详情 - CodeFun2000

python 复制代码
import sys
import numpy as np

def solve():
    data = sys.stdin.read().split()
    #print(data)

    if not data:
        return

    n = int(data[0])
    m = int(data[1])
    h = int(data[2])

    X = np.full((n, m), 1, dtype=np.float64)
    W1 = np.full((m, h), 0, dtype=np.float64)

    for i in range(m):
        for j in range(i,h):
            W1[i, j] = 1

    W2 = W1.copy()
    W3 = W1.copy()

    #print(X)
    #print(W1)
    #print(W2)
    #print(W3)

    Q = X @ W1
    K = X @ W2
    V = X @ W3

    #print(Q)
    #print(K)
    #print(V)

    QKT = Q @ K.T

    Y = QKT / np.sqrt(h)
    soft_Y = Y.copy()
    row_sum = np.zeros(n)
    for i in range(n):
        row = 0
        for j in range(n):
            row += Y[i, j]

        row_sum[i] = row

    for i in range(n):
        for j in range(n):
            soft_Y[i, j] = Y[i,j] / row_sum[i]

    Y_final = soft_Y @ V
    #print(Y_final)

    #print(np.round(np.sum(Y_final)))
    print(int(np.round(np.sum(Y_final))))

if __name__=='__main__':
    solve()
相关推荐
2401_871696522 小时前
JavaScript中代码覆盖率Coverage在精简脚本中的应用
jvm·数据库·python
XiYang-DING2 小时前
【Java EE】多线程(1)
java·python·java-ee
m0_734949792 小时前
Python GUI界面如何实现主题美化_引入ttk模块实现原生外观风格
jvm·数据库·python
光影少年2 小时前
Python+LangGraph学习路线及发展前景
开发语言·人工智能·python·学习
m0_678485452 小时前
如何让导航栏的下落动画效果更慢?
jvm·数据库·python
qq_432703662 小时前
Pandas DataFrame 分组聚合中处理 JSON 列的高效方法
jvm·数据库·python
qq_424098562 小时前
MySQL高负载下查询中断怎么解决_增加系统内存与调整参数
jvm·数据库·python
2301_773553622 小时前
SQL中如何处理多维数据的查询:复合索引与SELECT编写
jvm·数据库·python
大江东去浪淘尽千古风流人物2 小时前
【cuVSLAM】项目解析:一套偏工程实战的 GPU 紧耦合视觉惯性 SLAM
数据库·人工智能·python·机器学习·oracle