矩阵乘法与快速幂

矩阵乘法

  • 定义:

    给定矩阵 \(A\) 规模为 \(n\times m\) ,矩阵 \(B\) 规模为 \(m\times p\) ,定义 \(A\times B=C\) ,矩阵 \(C\) 规模为 \(n\times p\) ,满足:

    \[c_{ij}=\sum_{k=1}^ma_{ik}b_{kj} \]

    记住一个口诀:左行右列

  • 注意:

    对于矩阵乘法,满足乘法结合律和乘法分配律,不满足乘法交换律。

    举个例子:

    \(A=\)

    \(B=\)

    那么 \(A\times B=\)

    \(B\times A=\)

    可见 \(A\times B\neq B\times A\) 。

  • 代码实现:

cpp 复制代码
for(int i=1;i<=n;i++)
    for(int j=1;j<=p;j++)
        for(int k=1;k<=m;k++)
            c[i][j]=a[i][k]*b[k][j];

快速幂

  • 定义:

    在 \(O(\log(n))\) 的时间内求出 \(x^n\) 。

  • 原理:

    已知: 若 \(a+b=c\) ,则 \(x^a\times x^b=x^c\) 。

    那么将 \(n\) 转化为二进制,举个例子:

    \[x^{(13)_{10}}=x^{(1101)_2}=x^8\times x^4\times x^1 \]

    因为 \(n\) 有 \(\log(n)\) 个二进制位,所以在知道 \(x^1,x^2,x^3,...,x^{2^{\log(n)}}\) 前提下,即可在 \(O(\log(n))\) 的时间内求出 \(x_n\) 。

    问题转化位如何让求 \(x^1,x^2,x^3,...,x^{2^{\log(n)}}\),而在这个序列中,满足:

    \[x^{2^k}=\begin{cases} x&k=0\\ (x^{2^{k-1}})^2&k\geq 1\\ \end{cases}\]

    于是就同样可以在 \(O(\log(n))\) 的时间内,求出 \(x^1,x^2,x^3,...,x^{2^{\log(n)}}\) 。

    得出结论:计算 \(x^n\) ,只需将 \(n\) 的二进制位为 \(1\) 的整系数幂乘起来即可。

    于是发现,上述所有步骤均可以在 \(O(\log(n))\) 的复杂度上递推实现。

  • 代码实现:

    1. 计算 \(a^b\) 。

      cpp 复制代码
      int qpow(int a,int b)
      {
          int ans=1;
          for(;b;b>>=1)
          {
              if(b&1) ans*=a;
              a*=a;
          }
          return ans;
      }
    2. 计算 \(a^b\bmod P\)

      cpp 复制代码
      int qpow(int a,int b,int P)
      {
          int ans=1;
          for(;b;b>>=1)
          {
              if(b&1) (ans*=a)%=P;
              (a*=a)%=P;
          }
          return ans;
      }
    3. 根据费马小定理,$a^{P-1}≡1\pmod P $ ,前提,\(P\) 为质数。

      那么当 \(P\) 为质数时,可通过 \(a^{P-2}\bmod P\) 求出 \(a^{-1}\bmod P\) ,即 \(a\) 的乘法逆元。

      可用快速幂实现。

矩阵快速幂

  • 定义:

    基于基本的快速幂,将乘法运算扩展为矩阵乘法运算。

    由于矩阵乘法复杂度 \(O(n^3)\) ,快速幂复杂度 \(O(\log(n))\) ,所以矩阵快速幂复杂度为 \(O(n^3\log(n))\) 。

  • 代码实现:

    此处代码仅为样例,矩阵 \(A,B\) 规模均为 \(2\times 2\) ,根据不同题的需要进行更改。

    导入概念,如果矩阵 \(A\) 对角线上元素均为 \(1\) ,其余均为 \(0\) ,则 \(A\times B=B\) 。

    cpp 复制代码
    void qpow(int b)
    {
        memset(ans,0,sizeof(ans));
        for(int i=1;i<=2;i++)
            ans[i][i]=1;//int ans=1;
        for(;b;b>>=1)
        {
            if(b&1)
            {
                for(int i=1;i<=2;i++)   
                    for(int j=1;j<=2;j++)
                        for(int k=1;k<=2;k++)
                            (c[i][j]+=(ans[k][j]*a[i][k])%P)%P;
                for(int i=1;i<=2;i++)
                    for(int j=1;j<=2;j++)   
                        ans[i][j]=c[i][j],c[i][j]=0;
            }//if(b&1) ans*=a;
            for(int i=1;i<=2;i++)
                for(int j=1;j<=2;j++)
                    for(int k=1;k<=2;k++)
                        (c[i][j]+=(a[i][k]*a[k][j])%P)%P;
            for(int i=1;i<=2;i++)
                for(int j=1;j<=2;j++)
                    a[i][j]=c[i][j],c[i][j]=0;//a*=a;
        }
    }
  • 例题:

    • \(Fibonacci\) 第 \(n\) 项

      • 题意:

        \[f_x=\begin{cases} 1&x\in\{1,2\}\\ f_{x-1}+f_{x-2}&x \geq 3\\ \end{cases}\]

        给定一个 \(n\) ,求 \(f_n\) 。

        \(n\leq 2e9\) 。

      • 解法:

        \[f_n=1\times f_{n-1}+1\times f_{n-2} \]

        \[f_{n-1}=1\times f_{n-1}+0\times f_{n-2} \]

        通过矩阵乘法,可将其转化为:

        由此即可通过矩阵快速幂和矩阵乘法求出答案。
        点击查看代码

        cpp 复制代码
        #include<bits/stdc++.h>
        #define int long long 
        #define endl '\n'
        using namespace std;
        const int N=10;
        template<typename Tp> inline void read(Tp&x)
        {
            x=0;register bool z=true;
            register char c=getchar();
            for(;c<'0'||c>'9';c=getchar()) if(c=='-') z=0;
            for(;'0'<=c&&c<='9';c=getchar()) x=(x<<1)+(x<<3)+(c^48);
            x=(z?x:~x+1);
        }
        int n,P,ans[N][N],c[N][N],a[N][N];
        void qpow(int b)
        {
            for(;b;b>>=1)
            {
                if(b&1)
                {
                    for(int i=1;i<=2;i++)   
                        for(int j=1;j<=2;j++)
                            for(int k=1;k<=2;k++)
                                (c[i][j]+=(ans[k][j]*a[i][k])%P)%P;
                    for(int i=1;i<=2;i++)
                        for(int j=1;j<=2;j++)   
                            ans[i][j]=c[i][j],c[i][j]=0;
                }
                for(int i=1;i<=2;i++)
                    for(int j=1;j<=2;j++)
                        for(int k=1;k<=2;k++)
                            (c[i][j]+=(a[i][k]*a[k][j])%P)%P;
                for(int i=1;i<=2;i++)
                    for(int j=1;j<=2;j++)
                        a[i][j]=c[i][j],c[i][j]=0;
            }
        }
        signed main()
        {
            #ifndef ONLINE_JUDGE
            freopen("in.txt","r",stdin);
            freopen("out.txt","w",stdout);
            #endif
            read(n),read(P);
            a[1][1]=1;
            a[2][1]=1;
            a[1][2]=1;
            a[2][2]=0;
            for(int i=1;i<=2;i++)
                ans[i][i]=1;
            qpow(n-2);
            cout<<(ans[1][1]+ans[1][2])%P;//实则为 1*ans[1][1]+1*ans[1][2],1省去,参考矩阵乘法定义。
        }
    • 矩阵加速(数列)

      • 题意:

        与上一道类似的,可理解为稍微变化的 \(Fibonacci\) 第 \(n\) 项。

        \[f_x=\begin{cases} 1&x\in\{1,2,3\}\\ f_{x-1}+f_{x-3}&x \geq 4\\ \end{cases}\]

        给定一个 \(n\) ,求 \(f_n\) 。

        \(n\leq 2e9\) 。

      • 解法:

        也是和上面类似的,可以导出:

        ∵ \(f_n=f_{n-1}+f_{n-3}\)

        \(~~~~f_{n-1}=f_{n-2}+f_{n-4}\)

        \(~~~~f_{n-2}=f_{n-3}+f_{n-5}\)

        ∴ \(f_n=2\times f_{n-3}+1\times f_{n-4}+1\times f_{n-5}\)

        \(~~~~f_{n-1}=1\times f_{n-3}+1\times f_{n-4}+1\times f_{n-5}\)

        \(~~~~f_{n-2}=1\times f_{n-3}+0\times f_{n-4}+1\times f_{n-5}\)

        最后问题在于处理答案。

        这次和上面的不同,上面的每一组就是 \(f_{x}→f_{x+1}\) ,而对于这次而言,\(n\bmod 3\) 的结果不同,统计答案也不同。

        1. \(n\bmod 3=0\) ,对应矩阵中第一行,即 \(ans_{1,1}+ans_{1,2}+ans_{1,3}\) (因为 \(f_1=f_2=f_3=1\) ,所以省去 "\(1\times\)" ,下面不在解释) 。

        2. \(n\bmod 3=2\) ,对应矩阵中第二行,即 \(ans_{2,1}+ams_{2,2}+ans_{2,3}\) 。

        3. \(n\bmod 3=1\) ,对应矩阵中第三行,即 \(ans_{3,1}+ans_{3,2}+ans_{3,3}\) 。

        点击查看代码

        cpp 复制代码
        #include<bits/stdc++.h>
        #define int long long 
        #define endl '\n'
        using namespace std;
        const int N=10,P=1e9+7;
        template<typename Tp> inline void read(Tp&x)
        {
            x=0;register bool z=true;
            register char c=getchar();
            for(;c<'0'||c>'9';c=getchar()) if(c=='-') z=0;
            for(;'0'<=c&&c<='9';c=getchar()) x=(x<<1)+(x<<3)+(c^48);
            x=(z?x:~x+1);
        }
        int t,n,ans[N][N],c[N][N],a[N][N];
        void qpow(int b)
        {
            for(;b;b>>=1)
            {
                if(b&1)
                {
                    for(int i=1;i<=3;i++)   
                        for(int j=1;j<=3;j++)
                            for(int k=1;k<=3;k++)
                                (c[i][j]+=(ans[k][j]*a[i][k])%P)%P;
                    for(int i=1;i<=3;i++)
                        for(int j=1;j<=3;j++)   
                            ans[i][j]=c[i][j],c[i][j]=0;
                }
                for(int i=1;i<=3;i++)
                    for(int j=1;j<=3;j++)
                        for(int k=1;k<=3;k++)
                            (c[i][j]+=(a[i][k]*a[k][j])%P)%P;
                for(int i=1;i<=3;i++)
                    for(int j=1;j<=3;j++)
                        a[i][j]=c[i][j],c[i][j]=0;
            }
        }
        signed main()
        {
            #ifndef ONLINE_JUDGE
            freopen("in.txt","r",stdin);
            freopen("out.txt","w",stdout);
            #endif
            read(t);
            while(t--)
            {
                read(n);
                a[1][1]=2,a[1][2]=1,a[1][3]=1;
                a[2][1]=1,a[2][2]=1,a[2][3]=1;
                a[3][1]=1,a[3][2]=0,a[3][3]=1;
                memset(ans,0,sizeof(ans));
                for(int i=1;i<=3;i++)
                    ans[i][i]=1;
                qpow((n-1)/3);
                if(n%3==0) cout<<(ans[1][1]+ans[1][2]+ans[1][3])%P<<endl;
                else if(n%3==2) cout<<(ans[2][1]+ans[2][2]+ans[2][3])%P<<endl;
                else if(n%3==1) cout<<(ans[3][1]+ans[3][2]+ans[3][3])%P<<endl;
            }
        }

序言

后面(甚至前面)好多只是都要用到矩阵乘法与矩阵快速幂,所以就来补了。

顺便回顾一下快速幂,\(HEOI2024\) 考场上快速幂板子忘了怎么打了,索性重新推了一遍,好在当时推出来了(于是骗到 \(20pts\) )。

题还没有打完,但是知识点算是搞明白了,加深一下记忆继续打题,以防下面的题不知所措。