旅游景点 Tourist Attractions (壮压 DP)题解

简化题意

给定 \(n\) 个点和 \(m\) 条边组成的无向图,按照一定限制要求停留 \(2\sim k+1\) 共 \(k\) 个点(可以经过但不停留),求最短的从 \(1\) 出发到 \(n\) 的路径长。

限制情况如下:

共有 \(q\) 个限制,接下来 \(q\) 行,每行两个数 \(x,y\) ,表示 \(x\) 停留要在 \(y\) 之前。

数据范围 \(2\leq n\leq 2e4,1\leq m\leq 2e5,0\leq k \leq 20,0\leq q\leq\dfrac{k(k+1)}{2}\) 。

不卡内存版

这个版本就比较简单(相对于卡内存的而言)。

下面说思路:

最短路

用 \(dijkstra\) ,因为 \(spfa\) 会被卡。

需要跑 \(1\sim k+1\) 每个点为起点的单源最短路,因为下面 \(DP\) 要用到两点间最短路。(不用怕 \(T\) ,\(k\leq 20\) ,两点间指的是 \(1\sim k+1\) 加上 \(n\) 这 \(k+1\) 个点两点间最短路)。

具体怎么处理两点间最短路,看代码就知道了。
点击查看代码

cpp 复制代码
void dij(int s)
{
    for(int i=1;i<=n;i++) v[i]=0,dis[i]=0x3f3f3f3f;
    priority_queue<pair<int,int>>q;
    q.push(make_pair(0,s));
    dis[s]=0;
    while(!q.empty())
    {
        x=q.top().second;
        q.pop();
        if(v[x]) continue;
        v[x]=1;
        for(int i=head[x];i;i=nxt[i])
        {
            y=to[i],z=w[i];
            if(dis[x]+z<dis[y])
                dis[y]=dis[x]+z,
                q.push(make_pair(-dis[y],y));
        }
    }
    for(int i=1;i<=k+1;i++) d[s][i]=dis[i];
    d[s][0]=dis[n];//一个技巧,用 0 这点代替 n ,可以节省超多空间。
}

壮压 \(DP\)

状态表示

用一个二进制,若停留则该位为 \(1\) ,否则该位为 \(0\) 。

定义一个 \(f[i][j]\) 表示当前停留在第 \(i\) 个点,当前状态为 \(j\) 。

\(eg:\) 若停留 \(2、5、7\) 这三个点, \(k=6\) ,则当前状态为:\(100101\) ,即若第 \(i\) 个点停留,则第 \(i-1\) 位为 \(1\) 。

预处理限制

前面的限制,对于 \(x,y\) ,要求停留 \(x\) 在 \(y\) 之前。

翻译一下,要把他转换成二进制,举个例子,若限制在第 \(4\) 个点停留前,第 \(2,5,7\) 个点必须已经停留,则在第 \(3\) 位( \(4-1\) )为 \(1\) 之前,第 \(1,4,6\) 位必须已经为 \(1\) 。

即在第 \(3\) 位为 \(1\) 之前,当前状态若为 \(100101\) ,则第 \(3\) 位可以变为 \(1\) ,当然,\(100101\) 中那些 \(0\) 改为 \(1\) 也是可以的,如 \(111101\) 也是可以的。
处理如下

cpp 复制代码
while(q--)
    read(x),read(y),
    a[y]|=1<<(x-2);

判断的时候,设当前状态为 \(s\) ,若 \(s\& a[y]=a[y]\) ,则可以在第 \(y-1\) 位为 \(1\) 。

状态转移方程

在这里有两种方法,设当前停留在 \(i\) :

  1. 枚举下一个点,若当前状态符合下一个点的限制要求,则可以转化到下一个点 \(j\) ,那么到下一个点的状态就是将第 \(j-1\) 位变为 \(1\) 。
  2. 枚举上一个点 \(j\) ,则在上一个点 \(j\) 时的状态就为将第 \(i-1\) 位变为 \(0\) ,若改后状态在 \(j-1\) 位为 \(1\) ,则可以通过 \(j\) 的状态更新 \(i\) 的状态。

在这里先用第一种,卡内存版用的第二种。
先看代码

cpp 复制代码
for(int i=0;i<=k+1;i++)
        for(int j=0;j<=maxx;j++)
            f[i][j]=-1;
    f[1][0]=0;
    for(int s=0;s<=maxx;s++)
        for(int i=1;i<=k+1;i++)
            if(f[i][s]!=-1)
                for(int j=2;j<=k+1;j++)
                    if((s&a[j])==a[j])
                    {
                        int p=s|(1<<(j-2));
                        if(f[j][p]>f[i][s]+d[i][j]||f[j][p]==-1)
                            f[j][p]=f[i][s]+d[i][j];
                    }
    for(int i=1;i<=k+1;i++)
        if(f[i][maxx]!=-1)
            ans=min(ans,f[i][maxx]+d[i][0]);

转移方程见代码。

解释一下,现初始化成 \(-1\) ,表示当前 \(f\) 还没有更新过。\(f[1][0]=0\) ,即起点位置肯定是 \(0\) 。

第一城循环枚举当前状态 \(s\),第二层枚举当前点 \(i\),如果现在这个 \(f[i][s]\neq-1\) ,即其已经被更新过,那么就可以通过当前 \(f\) 去推下一个 \(f\) 了。显然的,如果 \(f[i][s]\) 还是 \(-1\) 呢,肯定不能用它继续往下推。

接着枚举下一个点 \(j\) ,则下一个状态就是 \(p=s|(1<<(j-2))\) ,当然的,只有 \((s\& a[j])=a[j]\) ,才可以转移。

如果 \(f[j][p]=-1\) ,也就是说他还没有被更新过,自然要更新他,否则就选择 \(\min(f[j][p],f[i][s]+d[i][j])\) 。

这个 \(d[i][j]\) 就是前面最短路中处理出的两点间的最短路。

最后处理答案,他这 \(k\) 个点都是要停的,也就是所有位都是 \(1\) ,即 \((1<<k)-1\) 。那么就是 \(i=1\sim k\) 中最小的 \(f[i][(1<<k)-1]\) 。
总代码如下

cpp 复制代码
#include<bits/stdc++.h>
#define endl '\n'
using namespace std;
const int N=2e4+10,M=4e5+10,K=25,S=1<<20;
template<typename Tp> inline void read(Tp&x)
{
    x=0;register bool z=1;
    register char c=getchar();
    for(;c<'0'||c>'9';c=getchar()) if(c=='-') z=0;
    for(;'0'<=c&&c<='9';c=getchar()) x=(x<<1)+(x<<3)+(c^48);
    x=(z?x:~x+1);
}
int n,m,k,q,a[K],dis[N],x,y,z,maxx,ans=0x3f3f3f3f;
int head[N],to[M],nxt[M],w[M],tot;
int f[K][S],d[K][K];
bool v[N];
void add(int x,int y,int z)
{
    nxt[++tot]=head[x];
    to[tot]=y;
    w[tot]=z;
    head[x]=tot;
}
void dij(int s)
{
    for(int i=1;i<=n;i++) v[i]=0,dis[i]=0x3f3f3f3f;
    priority_queue<pair<int,int>>q;
    q.push(make_pair(0,s));
    dis[s]=0;
    while(!q.empty())
    {
        x=q.top().second;
        q.pop();
        if(v[x]) continue;
        v[x]=1;
        for(int i=head[x];i;i=nxt[i])
        {
            y=to[i],z=w[i];
            if(dis[x]+z<dis[y])
                dis[y]=dis[x]+z,
                q.push(make_pair(-dis[y],y));
        }
    }
    for(int i=1;i<=k+1;i++) d[s][i]=dis[i];
    d[s][0]=dis[n];
}
signed main()
{
	#ifndef ONLINE_JUDGE
    freopen("in.txt","r",stdin);
    freopen("out.txt","w",stdout);
    #endif
    read(n),read(m),read(k);
    maxx=(1<<k)-1;
    for(int i=1;i<=m;i++)
        read(x),read(y),read(z),
        add(x,y,z),
        add(y,x,z);
    for(int i=1;i<=k+1;i++) dij(i);
    if(!k) cout<<dis[n],exit(0);
    read(q);
    while(q--)
        read(x),read(y),
        a[y]|=1<<(x-2);
    for(int i=0;i<=k+1;i++)
        for(int j=0;j<=maxx;j++)
            f[i][j]=-1;
    f[1][0]=0;
    for(int s=0;s<=maxx;s++)
        for(int i=1;i<=k+1;i++)
            if(f[i][s]!=-1)
                for(int j=2;j<=k+1;j++)
                    if((s&a[j])==a[j])
                    {
                        int p=s|(1<<(j-2));
                        if(f[j][p]>f[i][s]+d[i][j]||f[j][p]==-1)
                            f[j][p]=f[i][s]+d[i][j];
                    }
    for(int i=1;i<=k+1;i++)
        if(f[i][maxx]!=-1)
            ans=min(ans,f[i][maxx]+d[i][0]);
    cout<<ans;
}

卡内存版

好的恶心的来了------(不然他也不至于是道紫)

原来那个直接 \(MLE~69pts\) 了。

好现在让我们思考怎么卡内存。

  1. 联想一下之前棋盘类壮压 \(DP\) ,先用一个 \(dfs\) 处理出所有当前行的所有合法状态。

    这里不放借鉴一下,也是 \(dfs\) 处理状态,其中会存在一些不合法的,那么不合法的那干脆就不叫他有。

    处理的时候用到了 \(lowbit,log\) ,用来搞当前状态中每个 \(1\) 。
    dfs 如下

    cpp 复制代码
    void pre(int s,int x)
    {
        if(x>=k)
        {
            for(int i=s;i;i-=lowbit(i))
            {
                int y=log2(lowbit(i))+2;
                if((a[y]&s)!=a[y])
                    return ;
            }
            int sum=0;
            for(int i=s;i;i-=lowbit(i)) sum++;
            e[sum].push_back(s);
            has[s]=e[sum].size()-1;
            return ;
        }
        pre(s+(1<<x),x+1);
        pre(s,x+1);
    }

    那么在方程中,\(s\) 的那个位置直接变成下表就可以了,即代码中的 \(has\) 数组。

    好的,我们搞出了第一个卡内存的方法。
    点击查看代码

    cpp 复制代码
    #include<bits/stdc++.h>
    #define endl '\n'
    using namespace std;
    const int N=2e4+10,M=4e5+10,K=25,S=7390281;
    template<typename Tp> inline void read(Tp&x)
    {
        x=0;register bool z=true;
        register char c=getchar();
        for(;c<'0'||c>'9';c=getchar()) if(c=='-') z=0;
        for(;'0'<=c&&c<='9';c=getchar()) x=(x<<1)+(x<<3)+(c^48);
        x=(z?x:~x+1);
    }
    bitset<N>v;
    short n,k,q,x,y,z;
    short to[M],w[M]/*,logg[1<<20|1]*/;
    int m,tot,cnt,ans=0x3f3f3f3f;
    int head[N],nxt[M],a[K],dis[N];
    int f[K][S],d[K][K],sit[M],has[1<<20];
    void add(short x,short y,short z)
    {
        nxt[++tot]=head[x];
        to[tot]=y;
        w[tot]=z;
        head[x]=tot;
    }
    void dij(short s)
    {
        for(register short i=1;i<=n;i++) v[i]=false,dis[i]=0x3f3f3f3f;
        priority_queue<pair<int,short>>q;
        q.push(make_pair(0,s));
        dis[s]=0;
        while(!q.empty())
        {
            x=q.top().second;
            q.pop();
            if(v[x]) continue;
            v[x]=true;
            for(register int i=head[x];i;i=nxt[i])
            {
                y=to[i],z=w[i];
                if(dis[x]+z<dis[y])
                    dis[y]=dis[x]+z,
                    q.push(make_pair(-dis[y],y));
            }
        }
        for(register short i=1;i<=k+1;i++) d[s][i]=dis[i];
        d[s][0]=dis[n];
    }
    int lowbit(int x) {return x&-x;}
    void pre(int s,int x)
    {
        if(x>=k)
        {
            sit[++cnt]=s;
            for(int i=sit[cnt];i;i-=lowbit(i))
            {
                int y=log2(lowbit(i))+2;
                if((a[y]&sit[cnt])!=a[y])
                {
                    cnt--;
                    return ;
                }
            }
            return ;
        }
        pre(s+(1<<x),x+1);
        pre(s,x+1);
    }
    // void Log(int x)
    // {
    //     int sum=20;
    //     while(x) logg[x]=sum,sum--,x>>=1;
    // }
    signed main()
    {
        #ifndef ONLINE_JUDGE
        freopen("atr11b.in","r",stdin);
        freopen("out.txt","w",stdout);
        #endif
        //Log((1<<20));
        read(n),read(m),read(k);
        for(register int i=1;i<=m;i++)
            read(x),read(y),read(z),
            add(x,y,z),
            add(y,x,z);
        for(register short i=1;i<=k+1;i++) dij(i);
        if(!k) cout<<dis[n],exit(0);
        read(q);
        while(q--)
            read(x),read(y),
            a[y]|=1<<(x-2);
        pre(0,0);
        for(register short i=0;i<=k+1;i++)
            for(register int j=1;j<=cnt;j++)
                f[i][j]=-1;
        f[1][1]=0;
        sort(sit+1,sit+1+cnt);
        for(int i(1);i<=cnt;++i)
            has[sit[i]]=i;
        for(register int s=1;s<=cnt;s++)
            for(register short i=1;i<=k+1;i++)
                if(f[i][s]!=-1)
                    for(register short j=2;j<=k+1;j++)
                    {
                        register int p=has[sit[s]|(1<<(j-2))];
                        if(!p) continue;
                        if(f[j][p]>f[i][s]+d[i][j]||f[j][p]==-1)
                        f[j][p]=f[i][s]+d[i][j];
                    }
        for(register short i=1;i<=k+1;i++)
            if(f[i][cnt]!=-1)
                ans=min(ans,f[i][cnt]+d[i][0]);
        cout<<ans;
    }

    从代码中这些 \(short\) 中可以看出多么努力在卡内存。

    但坏消息是这么还是 \(MLE~89pts\),原因很简单,\(k=20,q=0\) 直接就噶了,所有状态全部合法,还是那么多内存。

  2. 正解做法是用一个 \(vector\) (当然数组也可以)存,下标是当前已经停留了几个点(即二进制中有几个 \(1\) ),对应值存状态。

    还是一样的,\(f\) 中那一维就不用状态改用下标了。

    这样的话他最多的一个情况也才 \(\text{C}^{10}_{20}=184756\) 个(很容易理解的)。

    而这样显然就需要开三维的 \(f\) 了,\(f[t][i][j]\) ,\(t\) 表示已经停留几个了(即二进制中有几个 \(1\) ),\(i\) 是当前点,\(j\) 是当前状态在 \(s\) 的这个 \(vector\) 中的下标。

    发现还得改进。

    • 开三维的话优化等于没优化甚至负优化,结合上面的代码可以发现第一位可以滚掉,所以滚掉第一位即可了,这样卡内存问题就解决了。

    • 上面的让他不合法状态直接不存在方法也可以接着用啊,虽然不用也能过了。这样的话其实还有个弊端,需要用一个新的数组存下标,和上面一样,但如果不用可以用位运算实现(虽然我不知道怎么实现)。

    好的卡内存问题解决了,又发现一个离谱的问题。

    好吧可能是自身问题,反正前面那个往下一个推的转移方程怎么搞都不对,只好换成枚举上一个的了 \(qwq\) 。

    具体他方法一为什么不行了我也布吉岛,貌似那个必须是排好序的状态,算了不管了。

    那方法二的话也非常简单,改一改就行了,上面也简单说了一下思路,看了代码就能理解了,这里再说几个注意的地方:

    最外面多套一层循环当前停留数 \(t\) ,每次循环都要初始化 \(f[t]\) 全部为 \(0x3f3f3f3f\) 。因为我们滚了嘛,所以要每次都初始化。

    在 \(DP\) 之前可以先把只停留一个点的状态都先处理出来,避免后面麻烦,那后面 \(t\) 的循环从 \(2\) 开始就行了。

    看代码理解吧:
    点击查看代码

    cpp 复制代码
    #include<bits/stdc++.h>
    #define endl '\n'
    using namespace std;
    const int N=2e4+10,M=4e5+10,K=25,S=2e5;
    template<typename Tp> inline void read(Tp&x)
    {
        x=0;register bool z=true;
        register char c=getchar();
        for(;c<'0'||c>'9';c=getchar()) if(c=='-') z=0;
        for(;'0'<=c&&c<='9';c=getchar()) x=(x<<1)+(x<<3)+(c^48);
        x=(z?x:~x+1);
    }
    bitset<N>v;
    short n,k,q,x,y,z;
    short to[M],w[M],logg[1<<20|1];
    int m,tot,ans=0x3f3f3f3f;
    int head[N],nxt[M],a[K],dis[N];
    int f[2][K][S],d[K][K],has[1<<20],start[K];
    vector<int>e[K];
    void add(short x,short y,short z)
    {
        nxt[++tot]=head[x];
        to[tot]=y;
        w[tot]=z;
        head[x]=tot;
    }
    void dij(short s)
    {
        for(register short i=1;i<=n;i++) v[i]=false,dis[i]=0x3f3f3f3f;
        priority_queue<pair<int,short>>q;
        q.push(make_pair(0,s));
        dis[s]=0;
        while(!q.empty())
        {
            x=q.top().second;
            q.pop();
            if(v[x]) continue;
            v[x]=true;
            for(register int i=head[x];i;i=nxt[i])
            {
                y=to[i],z=w[i];
                if(dis[x]+z<dis[y])
                    dis[y]=dis[x]+z,
                    q.push(make_pair(-dis[y],y));
            }
        }
        start[s]=dis[1];
        for(register short i=1;i<=k+1;i++) d[s][i]=dis[i];
        d[s][0]=dis[n];
    }
    int lowbit(int x) {return x&-x;}
    void pre(int s,int x)
    {
        if(x>=k)
        {
            for(int i=s;i;i-=lowbit(i))
            {
                int y=logg[lowbit(i)]+2;
                if((a[y]&s)!=a[y])
                    return ;
            }
            int sum=0;
            for(int i=s;i;i-=lowbit(i)) sum++;
            e[sum].push_back(s);
            has[s]=e[sum].size()-1;
            return ;
        }
        pre(s+(1<<x),x+1);
        pre(s,x+1);
    }
    void Log(int x)
    {
        int sum=20;
        for(;x;x>>=1) logg[x]=sum,sum--;
    }
    signed main()
    {
        #ifndef ONLINE_JUDGE
        freopen("in.txt","r",stdin);
        freopen("out.txt","w",stdout);
        #endif
        Log(1<<20);
        read(n),read(m),read(k);
        for(int i=0;i<=(1<<k)-1;i++) has[i]=0x3f3f3f3f;
        for(register int i=1;i<=m;i++)
            read(x),read(y),read(z),
            add(x,y,z),
            add(y,x,z);
        for(register short i=1;i<=k+1;i++) dij(i);
        read(q);
        while(q--)
            read(x),read(y),
            a[y]|=1<<(x-2);
        if(!k) cout<<dis[n],exit(0);
        pre(0,0);
        for(int i=2;i<=k+1;i++)
            for(int j=0;j<=e[i].size();j++)
                f[1][i][j]=f[0][i][j]=0x3f3f3f3f;
        for(int i=2;i<=k+1;i++)
            if(!a[i]&&has[1<<(i-2)]!=0x3f3f3f3f)
                f[1][i][has[1<<(i-2)]]=start[i];
        for(int t=2;t<=k;t++)
        {
            for(int i=0;i<e[t].size();i++) 
                for(int j=2;j<=k+1;j++)
                    f[t&1][j][i]=0x3f3f3f3f;
            for(int s=0;s<e[t].size();s++)
                for(int i=2;i<=k+1;i++)
                    if((e[t][s]&(1<<(i-2)))&&((e[t][s]&(~(1<<(i-2))))&a[i])==a[i])
                        for(int j=2;j<=k+1;j++)
                            if(i!=j&&(e[t][s]&(1<<(j-2)))&&has[e[t][s]&(~(1<<(i-2)))]!=0x3f3f3f3f)
                                f[t&1][i][s]=min(f[t&1][i][s],f[(t-1)&1][j][has[e[t][s]&(~(1<<(i-2)))]]+d[j][i]);
        }
        for(register short i=2;i<=k+1;i++)
                ans=min(ans,f[k&1][i][0]+d[i][0]);
        cout<<ans;
    }

    跑的相当快的,虽然不是最优解但也差不多,别的 \(oj\) 都是 \(3000ms||5000ms\) 的,\(luogu\) 上是 \(1000ms\) ,这样一部分代码就过不去了,但是我的可以过。

  3. 还有一个最最抽象但是能过的方法,将 \(1<<20\) 存成 \(1<<19\) ,用的时候再改回来,竟然能卡过,具体怎么做问 教主

悲惨的做题记录

首先光不卡内存版就交了 \(50\) 多遍才对,详情请见 \(\Large{深痛教训}\)

然后加上卡内存版的一共是交了 \(80\) 多遍才 \(A\) ,因为方法一一直调不对,加上一开始卡内存方式不对。

如图:

\(\Huge{qwq}\)

心想既然这么不容易(明明是粗心大意)做出来的题,就打篇题解吧。

最后有没有哪位大佬能帮忙解释一下为啥方法一在卡内存版里过不去。