插头 DP 学习笔记

1. 算法简介

插头 DP 常用于网格图的 DP 中。其核心在于设计如下几部分:

  • 分界线的设置。
  • 插头状态的定义。
  • 合并和延伸插头的转移。

其中,分界线的设置其实规定了 DP 转移的顺序,区别于状压 DP,插头 DP 通常是对每个格子而非整行转移;而插头状态的定义则代表着转移过程中前驱状态对后继的影响。合并插头时常常会进行大量的分类讨论,所以代码难度很大,需要特别注意细节,建议在纸上写好再来写 DP。

为了转移能顺利进行,分界线上应当存储了转移所需的全部信息,这也正是插头的作用。

在此基础上,我们还可以提出一种更广义的插头,也就是在前驱状态钦定一些东西需要在后继状态满足。

2. 例题

其实我感觉插头 DP 拿哈密顿回路计数作为模板有点不合理,本质上哈密顿回路计数是插头 DP 的一个高级应用,要用到哈密顿回路的括号表示来设计状态,这并不是显然的,且没有什么拓展性。个人认为 2.3 才是插头 DP 更简单的一个应用。

2.1 P5056 【模板】插头 DP

考虑用括号表示法刻画"哈密顿回路"的限制。具体而言,如下图所示,以中间绿色的分界线把哈密顿回路切开,然后考虑分界线上方部分的切面。

显然,分界线上方的哈密顿回路分别构成了若干个连通分量,且每个连通分量均为链的形式。我们将链左端的切面转化为左括号 \(\texttt{(}\),将链右端的切面转化为右括号 \(\texttt{)}\),其余切面转化为占位符 \(\texttt{\#}\)。那么图中的哈密顿回路的上半部分就可以用 \(\texttt{(\#(\#\#))\#\#}\) 表示。(一个连通分量列数更小的端点为左括号,列数更大的端点为右括号)

分界线上每个切面的状态就称为插头。因为当转移到 \((x, y)\) 时,分界线只有两个拐点,且都在 \((x, y)\) 处,所以我们称 \((x, y)\) 水平方向的切面为下插头,竖直方向的切面为右插头。

容易发现,哈密顿回路能用括号序列来刻画,关键就在于 任意两个连通分量的端点不会相交。这和括号匹配的性质恰好吻合。

而带占位符的括号序列可以使用三进制数来表示。为了减小常数,我们将三进制数按照四进制数来存储,因为二的正整数次幂作为进制可以使用位运算优化常数。但是有一个新的问题:转化为四进制数后,状态的值域过大。此时可以使用哈希表来存储状态。

设计插头 DP 为:\(\mathrm{dp}_{x, y, S}\) 表示当前考虑到位置 \((x, y)\),分界线的插头状态为 \(S\) 的方案数。

接下来考虑转移。外层肯定是依次枚举位置 \((x, y)\),重点在于内层转移。

首先障碍格的转移在判断合法性后直接继承即可。其余格子在转移的时候需要对右插头、下插头的括号类型进行分类讨论:

  • 无右插头,无下插头:因为每个非障碍格都必须铺,所以需要新建一个连通分量------右插头是左括号,下插头是右括号。
  • 有右插头,无下插头:右插头的括号类型不变,可以选择向右延伸或拐弯向下。
  • 无右插头,有下插头:下插头的括号类型不变,可以选择向下延伸或拐弯向右。
  • 有右插头,有下插头:此时会对两个连通分量进行合并,需要对两个插头的类型分类讨论。
    • 右插头是左括号,下插头是右括号:因为括号匹配不能相交,此时右插头与下插头一定处于同一连通分量中,能够转移的充要条件是 该格子是最后一个可转移的格子
    • 右插头是右括号,下插头是左括号:直接合并,消去两插头的括号,其余插头的括号不变。
    • 右插头是左括号,下插头是左括号:消去两插头的括号后,修改下插头所在连通分量的右括号为左括号。
    • 右插头是右括号,下插头是右括号:消去两插头的括号后,修改右插头所在连通分量的左括号为右括号。

注意,换行的时候需要将所有状态集体左移 \(2\) 位(对应着四进制下左移 \(1\) 位),以匹配新的分界点形态。同时在插头延伸的时候需要判断延伸的方向是否有障碍,否则在进行换行的时候会引发错误。

其中,最后两种情况的分类讨论可以画图进行理解。而后两种的转移则需要用到括号串 \(\pm 1\) 转化的性质去做。

如图所示,这是"右插头是右括号,下插头是右括号"情况的解释。

时间复杂度 \(O(nm3^n)\)。

一些有关代码的技巧:

  • 用哈希表来存储 DP 的状态,然后遍历哈希表中的每一个元素进行刷表法转移。
  • 当遍历到 \((x, y)\) 时,\((x, y)\) 的右插头对应着四进制数的第 \(y-1\) 位,而下插头对应着四进制数的第 \(y\) 位。
cpp 复制代码
#include <bits/stdc++.h>
#include <bits/extc++.h>
#define fi first
#define se second
#define eb(x) emplace_back(x)
#define pb(x) push_back(x)
#define lc(x) (tr[x].ls)
#define rc(x) (tr[x].rs)
using namespace std;
using namespace __gnu_pbds;
typedef long long ll;
typedef unsigned long long ull;
typedef unsigned int uint;
typedef long double ldb;
typedef __int128 i128;
using pi = pair<int, int>;
const int N = 15, M = 2e6;

struct HashTable{
    int B = 1999993, h[M], id[M], idx, ne[M];
    ll val[M];

    void clear() {
        memset(h, 0, sizeof(h));
        for(int i = 0; i <= idx; i++) {
            val[i] = id[i] = ne[i] = 0;
        }
        idx = 0;
    }

    void insert(int st, ll v)
    {
        for(int i = h[st % B]; i ; i = ne[i]) {
            if(id[i] == st) {
                val[i] += v;
                return;
            }
        }

        val[++idx] = v;
        id[idx] = st;
        ne[idx] = h[st % B];
        h[st % B] = idx;
    }
} dp[2];

int n, m, a[N][N], ex, ey;

int bit[N], bas[N];

int main()
{
    //freopen("sample.in", "r", stdin);
    //freopen("sample.out", "w", stdout);
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);

    // Input
    cin >> n >> m;
    for(int i = 1; i <= n; i++) {
        for(int j = 1; j <= m; j++) {
            char c; cin >> c;
            if(c == '.') {
                ex = i, ey = j;
                a[i][j] = 1;
            }
        }
    }

    // Init
    for(int i = 0; i < N; i++) {
        bit[i] = (i << 1); bas[i] = (1 << bit[i]);
    }

    // DP
    int now = 0, pre = 1;
    dp[now].insert(0, 1);

    ll ans = 0;
    for(int i = 1; i <= n; i++) {
        for(int j = 1; j <= dp[now].idx; j++) dp[now].id[j] <<= 2;

        for(int j = 1; j <= m; j++) {
            swap(now, pre);
            dp[now].clear();

            for(int k = 1; k <= dp[pre].idx; k++) {
                int st = dp[pre].id[k]; ll val = dp[pre].val[k];
                int rht = (st >> bit[j - 1]) & 3;
                int dwn = (st >> bit[j]) & 3;

                if(!a[i][j]) {
                    if(!rht && !dwn)
                        dp[now].insert(st, val);
                }
                else if(!rht && !dwn) {
                    if(a[i + 1][j] && a[i][j + 1])
                        dp[now].insert(st | (1 << bit[j - 1]) | (2 << bit[j]), val);
                }
                else if(!dwn) {
                    if(a[i][j + 1]) dp[now].insert(st + rht * (bas[j] - bas[j - 1]), val);
                    if(a[i + 1][j]) dp[now].insert(st, val);
                }
                else if(!rht) {
                    if(a[i + 1][j]) dp[now].insert(st + dwn * (bas[j - 1] - bas[j]), val);
                    if(a[i][j + 1]) dp[now].insert(st, val);
                }
                else {

                    if(rht == 1 && dwn == 2) {
                        if(i == ex && j == ey)
                            ans += val;
                    }
                    else if(rht == 2 && dwn == 1) {
                        dp[now].insert(st - rht * bas[j - 1] - dwn * bas[j], val);
                    }
                    else if(rht == 1 && dwn == 1) {
                        int sm = 0, vst = st - rht * bas[j - 1] - dwn * bas[j];
                        for(int p = j; ; p++) {
                            int typ = (st >> bit[p]) & 3;
                            if(typ == 1) sm++;
                            else if(typ == 2) sm--;
                            if(sm == 0) {
                                vst -= bas[p];
                                break;
                            }
                        }
                        dp[now].insert(vst, val);
                    }
                    else if(rht == 2 && dwn == 2) {
                        int sm = 0, vst = st - rht * bas[j - 1] - dwn * bas[j];
                        for(int p = j - 1; ; p--) {
                            int typ = (st >> bit[p]) & 3;
                            if(typ == 1) sm++;
                            else if(typ == 2) sm--;
                            if(sm == 0) {
                                vst += bas[p];
                                break;
                            }
                        }
                        dp[now].insert(vst, val);
                    }
                }
            }
        }
    }

    cout << ans;
    return 0;
}

拓展:如果题目中不限制一个哈密顿回路,而是可以有多个哈密顿回路的话,就不用记录每个括号到底是左还是右了,因为不需要区分合并时是否形成回路。可以用二进制状压存储括号的存在性,达到 \(O(nm2^m)\) 的复杂度。

2.2 P2289 [HNOI2004] 邮递员

下文中称行数为 \(n\),列数为 \(m\),与题目中的定义相反。

考虑观察题目中回路的形态:

  • 当 \(n = 1\) 或 \(m = 1\) 时,此时图的形态是一条链,因此从 \((1, 1)\) 走到头之后再走回来即可。方案是唯一的,因此答案为 \(1\)。
  • 当 \(n, m > 1\) 时,由于保证了 \(n\times m\) 是偶数,所以 \(n, m\) 中必然有一个数为偶数。我们一定可以构造出一条长度为 \(n\times m\) 的路径,且此时显然已经达到了答案的下界。
    • 假设 \(n\) 为偶数。我们从 \((1, 1)\) 开始先向下走一格抵达 \((2, 1)\),然后向右走到 \((2, m - 1)\),向下走到 \((3, m-1)\),向左走到 \((3, 1)\),如此往复走 S 形路线,在走到最后一行后沿着最后一列回到第一行即可。
    • 当 \(m\) 为偶数的时候同理,走竖向的 S 形路线即可。

因此当 \(n, m > 1\) 的时候路线的长度(边数)一定得是 \(n\times m\)。又由于三角形斜边长度大于直角边,所以不能斜着走。容易发现这就是让我们数有多少个有向哈密顿回路,可以直接套用 2.1 的做法,最后答案乘个 \(2\)。时间复杂度 \(O(nm3^n)\)。

注意答案是个很大的数,需要使用 __int128 存储。

cpp 复制代码
#include <bits/stdc++.h>
#include <bits/extc++.h>
#define fi first
#define se second
#define eb(x) emplace_back(x)
#define pb(x) push_back(x)
#define lc(x) (tr[x].ls)
#define rc(x) (tr[x].rs)
using namespace std;
using namespace __gnu_pbds;
typedef long long ll;
typedef unsigned long long ull;
typedef unsigned int uint;
typedef long double ldb;
typedef __int128 i128;
using pi = pair<int, int>;
const int N = 30, M = 2e6;

void write(i128 x) {
    if(x < 0) {x = -x; putchar('-'); }
    if(x < 10) { putchar('0' + x); return; }
    write(x / 10); putchar('0' + x % 10);
}

struct HashTable{
    int B = 1999993, h[M], id[M], idx, ne[M];
    i128 val[M];

    void clear() {
        memset(h, 0, sizeof(h));
        for(int i = 0; i <= idx; i++) {
            val[i] = id[i] = ne[i] = 0;
        }
        idx = 0;
    }

    void insert(int st, i128 v)
    {
        for(int i = h[st % B]; i ; i = ne[i]) {
            if(id[i] == st) {
                val[i] += v;
                return;
            }
        }

        val[++idx] = v;
        id[idx] = st;
        ne[idx] = h[st % B];
        h[st % B] = idx;
    }
} dp[2];

int n, m, a[N][N], ex, ey;

int bit[N], bas[N];

int main()
{
    //freopen("sample.in", "r", stdin);
    //freopen("sample.out", "w", stdout);
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);

    // Input
    cin >> m >> n;
    if(n == 1 || m == 1) {
        cout << 1;
        return 0;
    }

    for(int i = 1; i <= n; i++) {
        for(int j = 1; j <= m; j++) {
            char c = '.';
            if(c == '.') {
                ex = i, ey = j;
                a[i][j] = 1;
            }
        }
    }

    // Init
    for(int i = 0; i < N; i++) {
        bit[i] = (i << 1); bas[i] = (1 << bit[i]);
    }

    // DP
    int now = 0, pre = 1;
    dp[now].insert(0, 1);

    i128 ans = 0;
    for(int i = 1; i <= n; i++) {
        for(int j = 1; j <= dp[now].idx; j++) dp[now].id[j] <<= 2;

        for(int j = 1; j <= m; j++) {
            swap(now, pre);
            dp[now].clear();

            for(int k = 1; k <= dp[pre].idx; k++) {
                int st = dp[pre].id[k]; i128 val = dp[pre].val[k];
                int rht = (st >> bit[j - 1]) & 3;
                int dwn = (st >> bit[j]) & 3;

                if(!a[i][j]) {
                    if(!rht && !dwn)
                        dp[now].insert(st, val);
                }
                else if(!rht && !dwn) {
                    if(a[i + 1][j] && a[i][j + 1])
                        dp[now].insert(st | (1 << bit[j - 1]) | (2 << bit[j]), val);
                }
                else if(!dwn) {
                    if(a[i][j + 1]) dp[now].insert(st + rht * (bas[j] - bas[j - 1]), val);
                    if(a[i + 1][j]) dp[now].insert(st, val);
                }
                else if(!rht) {
                    if(a[i + 1][j]) dp[now].insert(st + dwn * (bas[j - 1] - bas[j]), val);
                    if(a[i][j + 1]) dp[now].insert(st, val);
                }
                else {

                    if(rht == 1 && dwn == 2) {
                        if(i == ex && j == ey)
                            ans += val;
                    }
                    else if(rht == 2 && dwn == 1) {
                        dp[now].insert(st - rht * bas[j - 1] - dwn * bas[j], val);
                    }
                    else if(rht == 1 && dwn == 1) {
                        int sm = 0, vst = st - rht * bas[j - 1] - dwn * bas[j];
                        for(int p = j; ; p++) {
                            int typ = (st >> bit[p]) & 3;
                            if(typ == 1) sm++;
                            else if(typ == 2) sm--;
                            if(sm == 0) {
                                vst -= bas[p];
                                break;
                            }
                        }
                        dp[now].insert(vst, val);
                    }
                    else if(rht == 2 && dwn == 2) {
                        int sm = 0, vst = st - rht * bas[j - 1] - dwn * bas[j];
                        for(int p = j - 1; ; p--) {
                            int typ = (st >> bit[p]) & 3;
                            if(typ == 1) sm++;
                            else if(typ == 2) sm--;
                            if(sm == 0) {
                                vst += bas[p];
                                break;
                            }
                        }
                        dp[now].insert(vst, val);
                    }
                }
            }
        }
    }

    write(2 * ans);
    return 0;
}

2.3 P4262 [Code+#3] 白金元首与莫斯科

先考虑弱化版:如果障碍物集合已经给定,如何求出合法方案数?

我们依然考虑从左到右、从上到下依次遍历 \((x, y)\),那么分界线就还是只会在 \((x, y)\) 处拐两次的形态。接下来尝试将 \(1\times 2, 2\times 1\) 的方块的信息融入分界线中,显然可以记录两个单位方块之间的分界线来记录某个 \(1\times 2, 2\times 1\) 的方块的存在性。于是分界线的状态可以压为一个 \(m+1\) 位的二进制数,表示每一条分界线上是否有方块覆盖,即是否为一个多米诺骨牌的插头。

状态定义和转移都与普通的插头 DP 无异,这样做的单次时间复杂度为 \(O(nm2^m)\)。

回到原问题中,如果要对每个钦定为障碍物的格子都做一遍复杂度显然会爆炸。而这个问题相当于是一个"撤销单点"的模型,两种常用的思路是缺一分治和记录前后缀信息。这个问题中我们无法快速处理两个不同障碍物集合的 DP 合并,但是却能处理网格前后两部分的合并。

因此考虑"记录前后缀信息"的 trick,分别正着和倒着做一遍插头 DP,记录历史前后缀状态,在钦定撤销点的时候枚举前后缀的状态进行合并即可。这部分的时间复杂度也是 \(O(nm2^m)\)。

cpp 复制代码
#include <bits/stdc++.h>
#include <bits/extc++.h>
#define fi first
#define se second
#define eb(x) emplace_back(x)
#define pb(x) push_back(x)
#define lc(x) (tr[x].ls)
#define rc(x) (tr[x].rs)
using namespace std;
using namespace __gnu_pbds;
typedef long long ll;
typedef unsigned long long ull;
typedef unsigned int uint;
typedef long double ldb;
typedef __int128 i128;
using pi = pair<int, int>;
const int N = 18, M = (1 << 18), mod = 1e9 + 7;
int n, m, a[N][N], f[N][N][M], g[N][N][M];

void add(int &x, int val) {
    x += val;
    if(x >= mod) x -= mod;
}

void DO_DP(int dp[N][N][M], int a[N][N]) {
    dp[1][0][0] = 1;
    for(int i = 1; i <= n; i++) {
        for(int j = 1; j <= m; j++) {
            for(int S = 0; S < (2 << m); S++) {
                int lft = (S >> (j - 1)) & 1;
                int up = (S >> j) & 1;
                int val = dp[i][j - 1][S];
                if(!a[i][j]) {
                    if(!lft && !up)
                        add(dp[i][j][S], val);
                }
                else if(!lft && !up) {
                    if(a[i][j + 1]) add(dp[i][j][S + (1 << j)], val);
                    if(a[i + 1][j]) add(dp[i][j][S + (1 << (j - 1))], val);
                    add(dp[i][j][S], val);
                }
                else if(!lft) {
                    add(dp[i][j][S - (1 << j)], val);
                }
                else if(!up) {
                    add(dp[i][j][S - (1 << (j - 1))], val);
                }
            }
        }

        if(i == n) continue;
        for(int S = 0; S < (1 << m); S++)
            dp[i + 1][0][S << 1] = dp[i][m][S];
    }
}
int Rev[M];


int main()
{
    //freopen("sample.in", "r", stdin);
    //freopen("sample.out", "w", stdout);
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);

    // Input
    cin >> n >> m;
    for(int i = 1; i <= n; i++) {
        for(int j = 1; j <= m; j++) {
            cin >> a[i][j]; a[i][j] ^= 1;
        }
    }

    // DP
    DO_DP(f, a);

    int TMP[N][N]; memcpy(TMP, a, sizeof(a));

    for(int i = 1; i <= n; i++)
        reverse(TMP[i] + 1, TMP[i] + m + 1);
    
    reverse(TMP + 1, TMP + n + 1);

    DO_DP(g, TMP);
    
    // Get Rev
    for(int i = 1; i < (2 << m); i++)
        Rev[i] = (Rev[i >> 1] >> 1) + (i & 1) * (1 << m);
        

    // Combine
    for(int i = 1; i <= n; i++) {
        for(int j = 1; j <= m; j++) {
            if(!a[i][j]) {
                cout << 0 << " ";
                continue;
            }

            int ans = 0;
            for(int S = 0; S < (2 << m); S++) {
                if((S >> j) & 1) continue;
                if((S >> (j - 1)) & 1) continue;
                if(!f[i][j - 1][S] || !g[n - i + 1][m - j][Rev[S]]) continue;
                add(ans, ll(f[i][j - 1][S]) * g[n - i + 1][m - j][Rev[S]] % mod);
            }

            cout << ans << " ";
        }
        cout << "\n";
    }

    return 0;
}

2.4 P3886 [JLOI2009] 神秘的生物

题意转化为:将方阵中的若干个点染黑,使得黑点的权值和最大,并且黑点形成一个连通块。

考虑插头 DP。注意本题中分界线设置为方格而不是方格之间的直线,因此分界线的长度即为 \(n\) 的值。并在 DP 的维度中记录分界线上所有黑点所处的连通块编号。

因为列数最多为 \(9\),所以分界线上最多有 \(5\) 个连通块,再加上一个白点的状态,需要用 \(6\) 进制数表示。为了卡常,可以使用 \(8\) 进制数存储。进一步优化,可以使用连通块的最小表示法,即从左到右依次将连通块赋予编号。

转移的时候需要注意:新加入的一个点如果是白点,加入后黑点的连通块个数不能减少,如果减少了就意味着加入了一个孤立的连通分量,这不符合题目的要求。

时间复杂度 \(O(n^3|S|)\)。其中 \(|S|\) 表示本质不同的连通块最小表示法个数。

cpp 复制代码
#include <bits/stdc++.h>
#include <bits/extc++.h>
#define fi first
#define se second
#define eb(x) emplace_back(x)
#define pb(x) push_back(x)
#define lc(x) (tr[x].ls)
#define rc(x) (tr[x].rs)
using namespace std;
using namespace __gnu_pbds;
typedef long long ll;
typedef unsigned long long ull;
typedef unsigned int uint;
typedef long double ldb;
typedef __int128 i128;
using pi = pair<int, int>;
const int N = 10, M = 1e6, B = 999983, inf = 0x3f3f3f3f;
int n, a[N][N], bit[N], bas[N], ans = -inf;

struct HashTable {
    int h[M], id[M], val[M], idx, ne[M];
    void clear() {
        memset(h, 0, sizeof(h));
        idx = 0;
    }
    void insert(int st, int v) {
        for(int i = h[st % B]; i ; i = ne[i]) {
            if(id[i] == st) {
                val[i] = max(val[i], v);
                return;
            }
        }

        ne[++idx] = h[st % B];
        h[st % B] = idx;
        id[idx] = st;
        val[idx] = v;
    }
} dp[2];

struct DSU {
    int fa[N];
    void init() {
        for(int i = 0; i < N; i++) fa[i] = i;
    }
    int findf(int x) {
        if(fa[x] != x) fa[x] = findf(fa[x]);
        return fa[x];
    }
    void combine(int x, int y) {
        int fx = findf(x), fy = findf(y);
        if(fx < fy) swap(fx, fy);
        fa[fx] = fy;
    }
} dsu;

int Trans(int st, int pos) {
    int res = 0; dsu.init();
    int val[N], nid[N], pre[N], cnt = 0;
    memset(pre, -1, sizeof(pre));
    val[0] = 0;
    pre[0] = 0;
    nid[0] = 0;

    for(int i = 1; i <= n; i++) {
        val[i] = (st >> bit[i]) & 7;
        nid[i] = -1;
        if(pre[val[i]] >= 0 && (pos != i || val[i]))
            dsu.combine(i, pre[val[i]]);
        if(pos != i || val[i])
            pre[val[i]] = i;
        if(i + 1 == pos && val[i]) dsu.combine(i, i + 1);
        if(i > 1 && val[i] && val[i - 1]) dsu.combine(i, i - 1);
    }

    for(int i = 1; i <= n; i++) {
        int f = dsu.findf(i);
        if(nid[f] < 0) nid[f] = ++cnt;
        res += bas[i] * nid[f];
    }

    return res;
}

bool check(int st) {
    for(int i = 1; i <= n; i++)
        if(((st / bas[i]) & 7) > 1)
            return 0;
    return 1;
}

int main()
{
    //freopen("sample.in", "r", stdin);
    //freopen("sample.out", "w", stdout);
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);

    // Init
    for(int i = 0; i < N; i++) {
        bit[i] = i * 3;
        bas[i] = (1 << bit[i]);
    }

    // Input
    cin >> n;
    for(int i = 1; i <= n; i++)
        for(int j = 1; j <= n; j++)
            cin >> a[i][j];

    // DP
    int now = 0, pre = 1;

    for(int i = 1; i <= n; i++) {
        for(int j = 1; j <= n; j++) {
            swap(now, pre);
            dp[now].clear();
            dp[now].insert(bas[j], a[i][j]);

            ans = max(ans, a[i][j]);
            for(int k = 1; k <= dp[pre].idx; k++) {
                // Combine
                int vst = Trans(dp[pre].id[k], j);
                dp[now].insert(vst, dp[pre].val[k] + a[i][j]);
                if(check(vst)) ans = max(ans, dp[pre].val[k] + a[i][j]);

                // Not Combine
                vst = dp[pre].id[k];
                int tmp = ((vst / bas[j]) & 7);
                bool flag = (tmp == 0); 
                for(int p = 1; p <= n; p++) {
                    if(p == j) continue;
                    if(((vst / bas[p]) & 7) == tmp) flag = 1;
                }
                if(flag) dp[now].insert(Trans(vst - bas[j] * ((vst / bas[j]) & 7), -1), dp[pre].val[k]);
            }
        }
    }

    cout << ans;
    return 0;
}

2.5 P3272 [SCOI2011] 地板

数据范围很小,且 L 形的性质很好,考虑插头 DP。先确定分界线的形态,显然直接套用一般的插头 DP 分界线即可。第二步考虑如何在分界线上记录信息,对于 L 形的地板,大致可以分为三类插头:空插头、后面必须收尾的插头、后面必须有拐点的插头,这三类插头分别记为 \(0, 1, 2\)。因此信息转化为一个三进制数,为了优化常数,转化为四进制数加上哈希表存储 DP 状态。

转移的时候对分界线拐点处的右插头、下插头状态进行分类讨论。

  • 当 \((x, y)\) 为障碍时,仅转移右插头、下插头均为空的。
  • 当 \((x, y)\) 非障碍时:
    • 若右插头、下插头均为空,需要新建一块 L 形地板,有三种可能:
      • \((x, y)\) 作为拐点,需要满足 \((x, y + 1), (x + 1, y)\) 不是障碍。
      • \((x, y)\) 作为 L 形端点,并向下延伸。需要满足 \((x + 1, y)\) 不是障碍。
      • \((x, y)\) 作为 L 形端点,并向右延伸。需要满足 \((x, y + 1)\) 不是障碍。
    • 若右插头、下插头均不为空,此时只有两插头均为 \(2\) 才能转移。
    • 若有右插头,无下插头:
      • 让右插头继续延伸下去,需要满足 \((x, y + 1)\) 不是障碍。
      • 让右插头进行拐弯或者收尾,具体是哪个取决于右插头的状态。
    • 若有下插头,无右插头:同理,此处不再赘述。

时间复杂度 \(O(nm2^{\min\{n, m\}})\)。注意一个细节,一开始如果 \(n < m\) 要将矩阵旋转一下再做插头 DP。

cpp 复制代码
#include <bits/stdc++.h>
#include <bits/extc++.h>
#define fi first
#define se second
#define eb(x) emplace_back(x)
#define pb(x) push_back(x)
#define lc(x) (tr[x].ls)
#define rc(x) (tr[x].rs)
using namespace std;
using namespace __gnu_pbds;
typedef long long ll;
typedef unsigned long long ull;
typedef unsigned int uint;
typedef long double ldb;
typedef __int128 i128;
using pi = pair<int, int>;
const int mod = 20110520;
const int N = 105, M = 1e6, B = 999983;
int n, m, a[N][N];

void add(int &x, int val) {
    x += val;
    if(x >= mod) x -= mod;
}

struct HashTable{
    int h[M], id[M], val[M], ne[M], idx;

    void clear() {
        memset(h, 0, sizeof(h));
        idx = 0;
    }

    void insert(int st, int v) {
        for(int i = h[st % B]; i ; i = ne[i]) {
            if(id[i] == st) {
                add(val[i], v);
                return;
            }
        }

        ne[++idx] = h[st % B];
        h[st % B] = idx;
        id[idx] = st;
        val[idx] = v;
    }
} dp[2];

void Rotate() {
    int b[N][N];
    memset(b, 0, sizeof(b));
    for(int i = 1; i <= n; i++) {
        for(int j = 1; j <= m; j++) {
            b[j][n - i + 1] = a[i][j];
        }
    }
    swap(n, m);
    memcpy(a, b, sizeof(b));
}

int bit[N], bas[N];

int main()
{
    //freopen("sample.in", "r", stdin);
    //freopen("sample.out", "w", stdout);
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);

    // Input
    cin >> n >> m;
    for(int i = 1; i <= n; i++) {
        for(int j = 1; j <= m; j++) {
            char c; cin >> c;
            a[i][j] = (c == '_');
        }
    }
    if(m > n) Rotate();

    // Init
    for(int i = 0; i <= 11; i++) {
        bit[i] = (i << 1);
        bas[i] = (1 << bit[i]);
    }

    // DP
    int now = 0, pre = 1;
    dp[now].insert(0, 1);

    for(int i = 1; i <= n; i++) {
        for(int j = 1; j <= dp[now].idx; j++) dp[now].id[j] <<= 2;

        for(int j = 1; j <= m; j++) {
            swap(now, pre);
            dp[now].clear();

            for(int k = 1; k <= dp[pre].idx; k++) {
                int st = dp[pre].id[k], val = dp[pre].val[k];
                int lft = (st >> bit[j - 1]) & 3;
                int up = (st >> bit[j]) & 3;

                if(!a[i][j]) {
                    if(!lft && !up)
                        dp[now].insert(st, val);
                }
                else if(!lft && !up) {
                    if(a[i + 1][j] && a[i][j + 1])
                        dp[now].insert(st | bas[j - 1] | bas[j], val);
                    if(a[i + 1][j])
                        dp[now].insert(st | (2 * bas[j - 1]), val);
                    if(a[i][j + 1])
                        dp[now].insert(st | (2 * bas[j]), val);
                }
                else if(lft && up) {
                    if(lft == up && lft == 2)
                        dp[now].insert(st ^ (2 * bas[j - 1]) ^ (2 * bas[j]), val);
                }
                else if(!up) {
                    if(a[i][j + 1])
                        dp[now].insert(st ^ (lft * bas[j - 1]) ^ (up * bas[j]) ^ (lft * bas[j]), val);
                    
                    if(lft == 1) {
                        dp[now].insert(st ^ (lft * bas[j - 1]), val);
                    }
                    else if(lft == 2) {
                        if(a[i + 1][j])
                            dp[now].insert(st - bas[j - 1], val);
                    }
                }
                else if(!lft) {
                    if(a[i + 1][j])
                        dp[now].insert(st ^ (lft * bas[j - 1]) ^ (up * bas[j]) ^ (up * bas[j - 1]), val);

                    if(up == 1) {
                        dp[now].insert(st ^ (up * bas[j]), val);
                    }
                    else if(up == 2) {
                        if(a[i][j + 1])
                            dp[now].insert(st - bas[j], val);
                    }
                }
            }
        }
    }

    for(int i = 1; i <= dp[now].idx; i++) {
        if(dp[now].id[i] == 0) {
            cout << dp[now].val[i];
            return 0;
        }
    }
    cout << 0;
    return 0;
}

2.6 P2337 [SCOI2012] 喵星人的入侵

这题代码太恶心了,不怎么想写,并且也都是套路的分讨,这里口胡一下我自己的做法,可能假了或者有点小问题。

下文中假设 \(m < n, m \le 6, n \le 20\)。

首先题意可以转化为:找到一条连通的简单路径,使得可能受到的攻击最大化。因为在求得攻击最大的一条路径后,我们可以将其他路径全部堵住。

先弱化条件,考虑如何数连通简单路径的个数。显然可以插头 DP,只是在哈密顿回路(左括号、右括号、占位符)的基础上加了一个插头:单端插头。也就是起点终点所在连通分量的插头。

这个插头 DP 还有一些可以压缩状态的其他性质:

  • 除了竖直分界线,任意两个相邻的分界线,不能同时作为括号插头(路径)。因为此时会形成一个 \(2\times 2\) 的连通区域,不满足简单路径的限制,可以剪枝。
  • 括号序列不合法的情况。

这些限制还挺强的,粗略估计总体的状态数 \(|S|\) 不会很大,可能也就 \(100\) 左右?

然后尝试加入炮塔的限制。DP 肯定是要新加两维的:当前遍历到的方格、当前使用的炮塔。但是我们仍然无法知道新加一个塔之后的贡献。因此我们考虑记录分界线上方,与分界线紧密相接的方格。方格有三个状态:当做障碍物、当做路径、当做炮塔。一共有 \(m + 1\) 个需要记录的方格,可以用三进制来压缩。

这样 DP 就设计完了,写个哈希表,直接分类讨论转移就好了。时间复杂度 \(O(nmk3^{m+1}|S|)\)。不知道对不对,如有错误欢迎指出。

参考资料

该博客的代码风格在很大程度上参考了 Alex_Wei 的博客。

相关推荐
KS_Fszha1 年前
Luogu P5298 PKUWC2018 Minimax 题解 [ 紫 ] [ 树形 dp ] [ 线段树合并 ] [ 概率 dp ]
数据结构·动态规划 dp