CCF CSP题解:矩阵运算(202305-2)

链接和思路

OJ链接:传送门

本题要求计算1个公式:
( W ⋅ ( Q × K T ) ) × V \left(\mathbf{W} \cdot (\mathbf{Q} \times \mathbf{K}^{T})\right) \times \mathbf{V} (W⋅(Q×KT))×V

其中, Q \mathbf{Q} Q、 K \mathbf{K} K和 V \mathbf{V} V均是 n n n行 d d d列的矩阵, K T \mathbf{K}^{T} KT,表示矩阵 K \mathbf{K} K的转置, × \times ×表示矩阵乘法。 ⋅ \cdot ⋅为点乘,即对应位相乘,记 W ( i ) \mathbf{W}^{(i)} W(i)为向量 W \mathbf{W} W的第 i i i个元素,即将 ( Q × K T ) (\mathbf{Q} \times \mathbf{K}^{T}) (Q×KT)第 i i i行中的每个元素都与 W ( i ) \mathbf{W}^{(i)} W(i)相乘。

本题有2点需要注意,否则只能过70%的样例:

  • 使用int会导致溢出,可使用long long表示数据。
  • 如果按照公式给出的顺序计算,复杂度为 O ( d n 2 ) O(dn^2) O(dn2),注意到 n n n远大于 d d d,因此应该修改运算顺序,优化到 O ( d 2 n ) O(d^2n) O(d2n)。

由于注意到矩阵乘法 A n × m × B m × k \mathbf{A}{n\times m} \times \mathbf{B}{m \times k} An×m×Bm×k的复杂度是 O ( n m k ) O(nmk) O(nmk),因此我们尽可能要让 m m m更小,于是原式的计算顺序可以改变为:
( W ⋅ ( Q × K T ) ) × V = W ⋅ ( Q × ( K T × V ) ) \left(\mathbf{W} \cdot (\mathbf{Q} \times \mathbf{K}^{T})\right) \times \mathbf{V} =\mathbf{W} \cdot \left(\mathbf{Q} \times (\mathbf{K}^{T} \times \mathbf{V} ) \right) (W⋅(Q×KT))×V=W⋅(Q×(KT×V))

调整矩阵乘法顺序在矩阵乘法计算中是十分常见的,如果是一连串任意给定的矩阵相乘,可以用动态规划的方法得到最优的矩阵运算效率。此外,使用行优先的方式比列优先更能充分利用缓存命中率,这也是优化矩阵乘法效率的一个思路,但是由于已经满分,因此在本题中我们没有继续优化。

AC代码

cpp 复制代码
#include <iostream>
#include <vector>

using namespace std;

void print_vector(const vector<vector<long long>> &arr) {
    for (int i = 0; i < arr.size(); i++) {
        for (int j = 0; j < arr[0].size(); j++) {
            if (j != 0)
                cout << " ";
            cout << arr[i][j];
        }
        cout << endl;
    }
}

int main() {
    int n, d;
    cin >> n >> d;
    vector<vector<long long>> q(n), k(n), v(n);
    vector<long long> w(n);
    for (int i = 0; i < n; ++i) {
        q[i].resize(d);
        for (int j = 0; j < d; ++j) {
            cin >> q[i][j];
        }
    }

    for (int i = 0; i < n; ++i) {
        k[i].resize(d);
        for (int j = 0; j < d; ++j) {
            cin >> k[i][j];
        }
    }
    for (int i = 0; i < n; ++i) {
        v[i].resize(d);
        for (int j = 0; j < d; ++j) {
            cin >> v[i][j];
        }
    }

    for (int i = 0; i < n; ++i) {
        cin >> w[i];
    }

    //kv: d x d
    vector<vector<long long>> kv(d);
    for (int i = 0; i < d; ++i) {
        kv[i].resize(d);
    }
    for (int i = 0; i < d; ++i) {
        for (int j = 0; j < d; ++j) {
            for (int l = 0; l < n; ++l) {
                kv[i][j] += k[l][i] * v[l][j];
            }
        }
    }

    //qkv: n x d
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < d; ++j) {
            k[i][j] = 0;
            for (int l = 0; l < d; ++l) {
                k[i][j] += q[i][l] * kv[l][j];
//                printf("k[%d][%d]=%d\n", i, j, k[i][j]);
            }
        }
    }

    // wqkv: n x d
    for (int i = 0; i < n; i++)
        for (int j = 0; j < d; ++j)
            k[i][j] *= w[i];
    print_vector(k);

    return 0;
}
相关推荐
apollowing10 分钟前
启发式算法WebApp实验室:从搜索策略到群体智能的能力进阶(上)
算法·启发式算法·web app
生物信息与育种34 分钟前
黄三文院士领衔植物星球计划(PLANeT)发表Cell
人工智能·深度学习·算法·面试·transformer
aini_lovee41 分钟前
WSN 四大经典无需测距定位算法
算法
人道领域41 分钟前
【LeetCode刷题日记】掌握二叉树遍历:栈实现的三种绝妙方法
算法·leetcode·职场和发展
北冥湖畔的燕雀42 分钟前
深入解析Linux信号处理机制
算法
阿Y加油吧1 小时前
二刷 LeetCode:动态规划经典双题复盘
算法·leetcode·动态规划
上弦月-编程1 小时前
C语言指针超详细教程——从入门到精通(面向初学者)
java·数据结构·算法
莫等闲-1 小时前
代码随想录一刷记录Day44——leetcode1143.最长公共子序列 53. 最大子序和
数据结构·c++·算法·leetcode·动态规划
生成论实验室1 小时前
《事件关系阴阳博弈动力学:识势应势之道》第七篇:社会与情感关系——连接、表达与共鸣
人工智能·算法·架构·交互·创业创新
承渊政道1 小时前
【动态规划算法】(背包问题经典模型与解题套路)
数据结构·c++·学习·算法·leetcode·动态规划·哈希算法