【题目全解】ACGO排位赛#13

ACGO排位赛#13 - 题目解析

感谢大家参加本次排位赛!

T1 - 纪元流星雨

题目链接跳转:点击跳转

也没有特别大的难度,手动模拟一下就可以了。

解题步骤

  1. 先计算出这个人一生中第一次看到流星雨的日子:\((E + B) \mod 50\)。
  2. 计算出剩余一生中可以看到流星雨的年份 \(Y\)。
  3. 答案就是 \(\dfrac{Y}{50} + 1\)。

代码实现

本题的 C++ 代码如下:

cpp 复制代码
#include <iostream>
using namespace std;

int solve(int B, int L, int E) {
    int age_at_first_shower = (E + B) % 50;
    if (age_at_first_shower > L) return 0;
    int years_from_first_shower = 
        L - age_at_first_shower;
    return years_from_first_shower / 50 + 1;
}

int main() {
    int T; cin >> T;
    for (int i = 0; i < T; i++) {
        int B, L, E;
        cin >> B >> L >> E;
        cout << solve(B, L, E) << '\n';
    }
    return 0;
}

本题的 Python 代码如下:

cpp 复制代码
T = int(input())
for _ in range(T):
    B, L, E = map(int, input().split(' '))
    before = B + E
    after = L - before
    print(before//50 + after//50 + 1)

T2 - MARCOncatenate

题目链接跳转:点击跳转

根据题目模拟就可以了,没有什么难度。

代码实现

本题的 C++ 代码如下:

cpp 复制代码
#include <iostream>
#include <string>

using namespace std;

string marco = "marco";
string capmarco = "MARCO";

string solve(string S) {
    int i = 0;
    int max = 0;
    for (int i = 1; i <= min(5, int(S.size())); ++i) {
        if (marco.substr(5 - i, i) == S.substr(0, i)) {
            max = i;
        }
    }
    if (max == 0) {
        return S;
    } else {
        return capmarco + S.substr(max, int(S.size()) - max);
    }
}

int main() {
    int T;
    cin >> T;
    for (int i = 0; i < T; i++) {
        string S;
        cin >> S;
        cout << solve(S) << '\n';
    }
    return 0;
}

本题的 Python 代码如下:

cpp 复制代码
def solve(S: str) -> str:    
    marco = "MARCO"
    marco_lower = "marco"
    matching_count = 0

    for i in range(7):
        if S[0:i] == marco_lower[5 - i:]:
            matching_count = i
            
    return marco + S[matching_count:] if matching_count != 0 else S

def main():
    T = int(input())
    for _ in range(T):
        S = input()
        print(solve(S))


if __name__ == '__main__':
    main()

T3 - TNT接力

题目链接跳转:点击跳转

这道题是本次比赛的第三道题,但是许多参赛选手认为这道题是本场比赛最难的题目。

解题思路

  1. 滑动窗口:使用滑动窗口技术,先在前 \(K\) 个方块中计算有多少个空气方块。接着,随着窗口向右滑动,我们移除窗口左端的一个方块,并加入窗口右端的一个方块,更新空气方块的数量。

  2. 最大空气方块数量:我们维护一个变量 mx,表示在当前窗口中最多有多少个空气方块。

  3. 计算答案:每次计算答案时,使用 \(K - \text{最大空气方块数量}\) 来表示最小的 TNT 方块数量,这表示需要多少步才能避免 TNT 的塌陷。如果 \(K \geq N\),表示玩家可以直接跳过整个桥,输出 \(-1\)​。

时间复杂度

每次处理一个序列的时间复杂度是 \(O(N)\),其中 \(N\) 是方块序列的长度。整体复杂度为 \(O(T \times N)\),其中 \(T\) 是测试用例的数量。关于本体,还可以用二分来进一步优化。本文不过多陈述。

代码实现

本题的 C++ 代码如下:

cpp 复制代码
#include <iostream>
#include <string>

using namespace std;

int solve(int N, int K, string S) {
    if (K >= N) return -1;
    ++K;
    int mx = 0, cur = 0;
    for (int i = 0; i < N; ++i) {
        if (S[i] == '-') ++cur;
        if (i - K >= 0 && S[i - K] == '-') --cur;
        mx = max(mx, cur);
    }
    return K - mx;
}

int main() {
    int T;
    cin >> T;
    for (int i = 0; i < T; i++) {
        int N, K;
        cin >> N >> K;
        string S;
        cin >> S;
        cout << solve(N, K, S) << '\n';
    }
    return 0;
}

本题的 Python 代码如下:

python 复制代码
def solve(N: int, K: int, S: str) -> int:
    if K >= N:
        return -1
    mx, cur = 0, 0
    K += 1
    for i in range(N):
        if S[i] == '-':
            cur += 1
        if i - K >= 0 and S[i - K] == '-':
            cur -= 1
        mx = max(mx, cur)
    return K - mx


def main():
    T = int(input())
    for _ in range(T):
        temp = input().split()
        N, K = int(temp[0]), int(temp[1])
        S = input()
        print(solve(N, K, S))


if __name__ == '__main__':
    main()

T4 - 小丑牌

题目链接跳转:点击跳转

解题思路

也是一道模拟题,但需要注意以下几点:

  1. 同一个点数可能会出现五次,那么此时应该输出 High Card(如题意)。
  2. 如果有一个牌型符合上述多条描述,请输出符合描述的牌型中在规则中最后描述的牌型。
  3. 牌的数量不局限于传统的扑克牌,每张牌可以题目中四种花色的任意之一。

代码实现

本题的 C++ 代码如下:

cpp 复制代码
#include <iostream>
#include <algorithm>
using namespace std;

int t;
struct card{
    int rank;
    string suit;
} cards[10];

bool cmp(card a, card b){
    return a.rank < b.rank;
}

int main(){
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
	cin >> t;
    while(t--){
        for (int i=1; i<=5; i++){
            string rank; string suit;
            cin >> rank >> suit;
            int act;
            if (rank == "J") act = 11;
            else if (rank == "Q") act = 12;
            else if (rank == "K") act = 13;
            else if (rank == "A") act = 14;
            else if (rank == "10") act = 10;
            else act = rank[0] - '0';
            cards[i] = (card){act, suit};
        }
        sort(cards+1, cards+6, cmp);
        int cnt[20] = {};
        bool isSameSuit = true;
        bool isRanked = true;
        int pairs = 0; int greatest = 0;
        for (int i=1; i<=5; i++){
            cnt[cards[i].rank] += 1;
            if (i > 1 && cards[i].suit != cards[i-1].suit)
                isSameSuit = false;
            if (i > 1 && cards[i-1].rank + 1 != cards[i].rank)
				isRanked = false;
        }
        for (int i=1; i<=15; i++){
            if (cnt[i] == 2) pairs++;
            greatest = max(greatest, cnt[i]);
        }
        if (isRanked && isSameSuit){
            if (cards[5].rank == 14) cout << "Royal Flush" << endl;
            else cout << "Straight Flush" << endl;
        } else if (isRanked) cout << "Straight" << endl;
        else if (pairs == 1 && greatest == 3) cout << "Full House" << endl;
        else if (greatest == 4) cout << "Four of a Kind" << endl;
        else if (greatest == 3) cout << "Three of a Kind" << endl;
        else if (pairs == 2) cout << "Two Pairs" << endl;
        else if (pairs == 1) cout << "One Pair" << endl;
        else cout << "High Card" << endl;
    }
    return 0;
}

T5 - Vertex Verse

题目链接跳转:点击跳转

直接模拟情况就可以了,但是细节比较多需要注意一下。

时间复杂度

其中 work 函数会在每次迭代中被调用 \(4\) 次,每次的复杂度是 \(O(E)\)。因此,对于每个输入的 \(q\) 对点,总的时间复杂度是 \(\Theta(4 \times q \times E)\),即\(O(q \times E)\)。但在最坏情况下,图中的边数 \(E\) 可以接近 \(O(n \times m)\),因此总时间复杂度是 \(O(q \times n \times m)\)​。

代码实现

本题的 C++ 代码如下:

cpp 复制代码
#include <iostream>
#include <cstring>
#include <map>
#include <algorithm>
using namespace std;

int n, m, q, ei;
int a, b, c, d;
int macw, alex;
struct perEdge{
    int to;
    int next;
} edges[2000005];
int vertex[1000005], cnt;
bool vis[1000005], memo[1000005][5], track;

inline int calc_dir(int x, int y){
    if (x + 1 == y) return 1;
    if (x + m == y) return 2;
    if (x + m + 1 == y) return 3;
    return -1;
}

void add(int v1, int v2){
    ei += 1;
    edges[ei].to = v2;
    edges[ei].next = vertex[v1];
    vertex[v1] = ei;
}

void dfs(int x, int steps, int origin, int dir){
    if (steps > 4) return ;
    if (steps == 4 && x == origin){
        // 说明走一圈可以走到原点。
        // 判断这个环是否已经被之前记录过了。
        if (memo[origin][dir]) return ;
        memo[origin][dir] = 1;
        cnt += 1;
        return ;
    }
    for (int index = vertex[x]; index; index = edges[index].next){
        int to = edges[index].to;
        // 只走编号比自己大的点。
        if (to >= origin || (to == origin && steps + 1 == 4)){
            if (vis[to]) continue;
            vis[to] = 1;
            dfs(to, steps + 1, origin, dir);
            vis[to] = 0;
        }
    }
}

void work(int x){
    for (int index = vertex[x]; index; index = edges[index].next){
        int to = edges[index].to;
        if (to <= x) continue;
        int dir = calc_dir(x, to);
        if (dir != -1){
            vis[to] = 1;
            dfs(to, 1, x, dir);
            vis[to] = 0;
        }
    }
}

int main(){
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    cin >> n >> m >> q;
    for (int i=1; i<=2*q; i++){
        cin >> a >> b >> c >> d;
        int v1 = (a - 1) * m + b;
        int v2 = (c - 1) * m + d;
        add(v1, v2); add(v2, v1);
        // 从v1点开始走四步,看一下能否回到原点
        cnt = 0;
        if (a == c){ 
            // 在同一排
            work(v1); work(v2);
            work(v1 - m); work(v2 - m);
        } else if (b == d){
            // 在同一列
            work(v1); work(v2);
            work(v1 - 1); work(v2 - 1);
        }
        if (i % 2) macw += cnt / 2;
        else alex += cnt / 2;
    }
    cout << macw << " " << alex << endl;
    return 0;
}

另一种写法如下:

cpp 复制代码
#include <iostream>
#include <unordered_map>
#include <map>
#include <vector>
#include <utility>
using namespace std;
typedef long long ll;
typedef pair<int,int> pi;
const int N = 2e5 + 10;
map<pi,map<pi,int>> dis;
int mp[2];
bool check(pi a,pi b,pi c,pi d) {
    return dis[a][b] && dis[a][c] && dis[c][d] && dis[b][d];
}
int main() {
    int n,m,q;
    cin >> n >> m >> q;
    q <<= 1;
    for (int i=0;i<q;i++) {
        pi a,b;
        cin >> a.first >> a.second >> b.first >> b.second;
        if (a > b) swap(a,b);
        dis[a][b]++;
        if (dis[a][b] > 1) {
            continue;
        }
        if (a.first == b.first) {
            pi c = make_pair(a.first + 1,a.second);
            pi d = make_pair(b.first + 1,b.second);
            if (check(a,b,c,d)) {
                mp[i%2]++;
            }
            pi e = make_pair(a.first - 1,a.second);
            pi f = make_pair(b.first - 1,b.second);
            if (check(e,f,a,b)) {
                mp[i%2]++;
            }
        }else {
            pi c = make_pair(a.first,a.second + 1);
            pi d = make_pair(b.first,b.second + 1);
            if (check(a,c,b,d)) {
                mp[i%2]++;
            }
            pi e = make_pair(a.first,a.second - 1);
            pi f = make_pair(b.first,b.second - 1);
            if (check(e,a,f,b)) {
                mp[i%2]++;
            }
        }
    }
    cout << mp[0] << " " << mp[1] << endl;
    return 0;
}

T6 - 最优政府大楼选址-2

题目链接跳转:点击跳转

解题思路

本题有好多解决方法,可以使用带权中位数写 \(N = 10^5\)​,为了考虑到朴素的模拟退火算法,本题的数据范围被适当降低了。如果学过模拟退火算法做这道题就非常的简单,把模板的评估函数改成计算意向程度的函数即可。

时间复杂度

模拟退火算法的时间复杂度约为 \(\Theta(k \times N)\)。其中 \(k\) 表示的是在模拟退火过程中计算举例的 F2 函数的调用次数。\(N\) 表示数据规模。

代码实现

本题的 C++ 代码如下(模拟退火):

cpp 复制代码
#include <iostream>
#include <algorithm>
#include <cmath>
#include <random>
using namespace std;

double n;
struct apart{
    double x, y;
} arr[10005];
double ei[10005];
double answ = 1e18, ansx, ansy;

double dis(double x1, double y1, double x2, double y2, int c){
    return (abs(x1 - x2) + abs(y1 - y2)) * ei[c];
}

double F2(double x, double y){
    double ans = 0;
    for (int i=1; i<=n; i++){
        ans += dis(x, y, arr[i].x, arr[i].y, i);
    }
    return ans;
}

void SA(){
    double T = 3000, cold = 0.999, range = 1e-20;
    answ = F2(ansx, ansy);
    while(T > range){
        double ex = ansx + (rand() * 2.0 - RAND_MAX) * T;
        double ey = ansy + (rand() * 2.0 - RAND_MAX) * T;
        double ea = F2(ex, ey);
        double dx = ea - answ;
        if (dx < 0){
            ansx = ex;
            ansy = ey;
            answ = ea;
        } else if (exp(-dx/T) * RAND_MAX > rand()){
            ansx = ex;
            ansy = ey;
        }
        T *= cold;
    }
    return ;
}

signed main(){
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    cin >> n;
    for (int i=1; i<=n; i++) cin >> ei[i];
    for (int i=1; i<=n; i++){
        cin >> arr[i].x >> arr[i].y;
        ansx += arr[i].x; ansy += arr[i].y;
    }
    ansx /= n; ansy /= n;
    for (int i=1; i<=1; i++) SA();
    printf("%.5lf\n", answ);
    return 0;
}

使用加权中位数算法的 C++ 代码:

cpp 复制代码
#include <bits/stdc++.h>

constexpr double EPS = 1e-6;

double get_med(const std::vector<double> &a, const std::vector<int> &e) {
    int n = a.size();
    double lo = -1e3, hi = 1e3;
    auto f = [&](double x) {
        double sum = 0.0;
        for (int i = 0; i < n; ++i)
            sum += std::abs(x - a[i]) * e[i];
        return sum;
    };
    while (hi - lo > EPS) {
        double mid = (lo + hi) / 2;
        if (f(mid - EPS) > f(mid) and f(mid) > f(mid + EPS))
            lo = mid;
        else
            hi = mid;
    }
    return lo;
}

int main() {
    int n; std::cin >> n;
    std::vector<int> e(n);
    for (auto &x : e) std::cin >> x;
    std::vector<double> x(n), y(n);
    for (int i = 0; i < n; ++i)
        std::cin >> x[i] >> y[i];
    double ax = get_med(x, e), ay = get_med(y, e);
    double res = 0.0;
    for (int i = 0; i < n; ++i)
        res += (std::abs(ax - x[i]) + std::abs(ay - y[i])) * e[i];
    std::cout << std::setprecision(6) << std::fixed << res << '\n';
    return 0;
}

T7 - 乌龟养殖场

题目链接跳转:点击跳转

前置知识:

  1. 了解过基本的动态规划。
  2. 熟练掌握二进制的位运算。

至于为什么放了一道模版题,原因是因为需要凑到八道题目,实在凑不到了,找了一个难度适中的。

题解思路

这是一道典型的状压动态规划 问题。设 \(dp_{i, j}\) 表示遍历到第 \(i\) 行的时候,当前行以 \(j_{(base2)}\) 的形式排列乌龟可以构成的方案数。

对于每一行的方案,我们可以用一个二进制来表示。例如二进制数字 \(10100\),表示有一个横向长度为 \(5\) 的场地中,第 \(1, 3\) 号位置分别放置了一只小乌龟。因此,每一种摆放状态都可以用一个二进制数字来表示。我们也可以通过遍历的方式来遍历出二进制的每一种摆放状态。

首先,我们预处理出横排所有放置乌龟的合法情况。根据题意,两个乌龟不能相邻放置,因此在二进制中,不能有两个 \(1\) 相邻。如何预处理出这种情况呢?我们可以使用位运算的方法:

如果存在一个二进制数字有两个 \(1\) 相邻,那么如果我们对这个数字 \(x\) 进行位运算操作 (x << 1) & x 的结果或 (x >> 1) & x 的结果必定大于等于 \(1\)。我们通过把这种情况排除在外。同时,我们还需要注意有些格子中不能放置乌龟。这一步也可以通过二进制的方法预处理掉,如果网箱在第 \(i\) 一个格子中不能放置乌龟,那么在枚举所有方案数的时候直接忽略掉第 \(i\) 位为 \(1\) 的情况即可。

接下来如何保证上下两行的乌龟不冲突?假如上一行的摆放状态是 \(y\),当前行的摆放状态为 \(j\),如果 i & j 的结果大于等于 \(1\),也可以证明有两个数字 \(1\) 在同一位置上。因此我们也需要把这种情况排除在外。

综上所述,我们可以得出状态转移方程:\(dp_{i, j} = dp_{i, j} + dp_{i-1, k}\)。其中,\(j\) 和 \(k\) 表示所有横排合法的方案。答案就是 \(\mathtt{ANS} = \sum_{j=0}^{2^M-1}{dp_{N, j}}\)。

状态的初始化也很简单,另 \(dp_{0, 0} = 1\)​,表示一只乌龟都不放有一种摆放方案。

时间复杂度

通过观察上述代码,在枚举所有状态和转移状态的时候有三层循环,分别是枚举当前行、枚举当前行的合法摆放情况以及枚举上一行的摆放情况。因此总时间复杂度约为 \(O(n \times 2^M \times 2^M) = O(n \times 2^{M^2}) = O(n \times 4^M)\)。但由于合法的摆放数量远远少于 \(2^M\),因此实际情况下程序运行的速度会快许多。

代码实现

本题的代码实现如下。在输出的时候需要减一,因为不放置也是一种合法情况,根据题目要求需要把这一合法情况排除。

cpp 复制代码
#include <iostream>
using namespace std;

const int MOD = 1e9+7;
int n, m, ans;
int arr[505][505];
// 所有横排合法的情况。
int terrain[505];
int ok[1050], cnt;
int dp[505][1050];

int main(){
    cin >> n >> m;
    for (int i=1; i<=n; i++){
        for (int j=1; j<=m; j++){
            cin >> arr[i][j];
        }
    }
    
    // 预处理非法地形。
    for (int i=1; i<=n; i++){
        for (int j=1; j<=m; j++){
            terrain[i] = (terrain[i] << 1) + !arr[i][j];
        }
    }
    
    // 预处理出所有横排的合法情况。
    for (int i=0; i<(1<<m); i++){
        if (((i<<1)|(i>>1)) & i) continue;
        ok[++cnt] = i;
    }
    dp[0][1] = 1;

    // 枚举。
    for (int i=1; i<=n; i++){
        for (int s1=1; s1<=cnt; s1++){  // 枚举当前行。
            if (ok[s1] & terrain[i]) continue;
            for (int s2=1; s2<=cnt; s2++){  // 枚举上一行。
                if (ok[s2] & terrain[i-1]) continue;
                if (ok[s1] & ok[s2]) continue;
                dp[i][s1] = (dp[i][s1] + dp[i-1][s2]) % MOD;
            }
        }
    }

    // 统计答案。
    int ans = 0;
    for (int i=1; i<=cnt; i++)
        ans = (ans + dp[n][i]) % MOD;
    
    cout << ans - 1 << endl;
    return 0;
}

本题的 Python 代码如下,Python 可以通过本题的所有测试点:

python 复制代码
MOD = int(1e9 + 7)
n, m, ans = 0, 0, 0
arr = [[0] * 505 for _ in range(505)]
terrain = [0] * 505
ok = [0] * 1050
dp = [[0] * 1050 for _ in range(505)]
cnt = 0

def main():
    global n, m, cnt, ans
    
    # 输入 n 和 m
    n, m = map(int, input().split())
    
    # 输入 arr 数组
    for i in range(1, n + 1):
        arr[i][1:m + 1] = map(int, input().split())
    
    # 预处理非法地形
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            terrain[i] = (terrain[i] << 1) + (1 - arr[i][j])
    
    # 预处理出所有横排的合法情况
    for i in range(1 << m):
        if ((i << 1) | (i >> 1)) & i:
            continue
        cnt += 1
        ok[cnt] = i
    
    dp[0][1] = 1
    
    # 枚举
    for i in range(1, n + 1):
        for s1 in range(1, cnt + 1):  # 枚举当前行
            if ok[s1] & terrain[i]:
                continue
            for s2 in range(1, cnt + 1):  # 枚举上一行
                if ok[s2] & terrain[i - 1]:
                    continue
                if ok[s1] & ok[s2]:
                    continue
                dp[i][s1] = (dp[i][s1] + dp[i - 1][s2]) % MOD
    
    # 统计答案
    ans = 0
    for i in range(1, cnt + 1):
        ans = (ans + dp[n][i]) % MOD
    
    print(ans - 1)

if __name__ == "__main__":
    main()

再提供一个暴力解法用于对拍:

cpp 复制代码
#include <iostream>
using namespace std;

const int MOD = 1e9+7;
int n, m, ans;
int arr[505][505];
int dx[] = {0, 1, -1, 0};
int dy[] = {1, 0, 0, -1};

// 深度优先搜索 Brute Force
void dfs(int x, int y){
    if (x > n) {
        ans += 1;
        ans %= MOD;
        return ;
    }
    if (y > m){
        dfs(x+1, 1);
        return ;
    }
    if (arr[x][y] == 0){
        dfs(x, y+1);
        return ;
    }
    // 不放鱼
    dfs(x, y+1);

    // 放鱼
    for (int i=0; i<4; i++){
        int cx = x + dx[i];
        int cy = y + dy[i];
        if (cx < 1 || cy < 1 || cx > n || cy > m) continue;
        if (arr[cx][cy] == 2) return ;
    }
    arr[x][y] = 2;
    dfs(x, y+1);
    arr[x][y] = 1;
    return ;
}

int main(){
    cin >> n >> m;
    for (int i=1; i<=n; i++){
        for (int j=1; j<=m; j++){
            cin >> arr[i][j];
        }
    }
    // dfs 暴力
    dfs(1, 1);
    cout << ans-1 << endl;
    return 0;
}

T8 - 数据中心能耗分析

题目链接跳转:点击跳转

本文仅针对对线段树有一定了解且并且有着较高的程序设计能力的选手。本文的前置知识如下:

  1. 线段树的构造与维护 - 可以参考文章 浅入线段树与区间最值问题
  2. 初中数学 - 完全平方和公式和完全立方和公式。
  3. 取模 - 之如何保证对所有整数取模后的结果为非负整数 - 可以参考本题的 说明/提示 部分。

原本出题的时候我并不知道这道题在许多 OJ 上是有原题的,so sad(下次再改进)。

题目本身应该非常好理解,就是给定一个数组,让你设计一个程序,对程序进行区间求立方和和区间修改的操作。但本题的数据量比较大,\(N, M\) 最大可以达到 \(10^5\),对于每一次修改和查询都是 \(O(N)\) 的时间复杂度,显然用暴力的方法时间复杂度绝对会超时,最高会到 \(O(N^2 * M)\) (大概需要 \(115\) 天的时间才可以搞定一个测试点)。当看到区间查询和维护操作的时候,不难想到用线段树的方法,在线段树中,单次操作的时间复杂度约为 \(O(log_2 N)\),即使当 \(N\) 非常大的时候线段树也可以跑得飞起。

解题思路

不得不说,这是一道比较恶心的线段树区间维护的题目。不光写起来比较费劲,而且维护操作运算量比较大。稍有不慎就会写歪(因此写这道题的时候要集中注意力,稍微一个不起眼的问题就容易爆 \(0\))。

本题的主要难点就是对一个区间内进行批量区间累加的操作。很容易就想到跟完全立方公式的联系:\((a+b)^3 = a^3 +3a^2b + 3ab^2 + b^3\)。区间累加操作也只不过是对区间的所有数字都进行该操作,并对所有操作的结果求和就是该区间进行操作后的立方和。化简可得:

\[\begin{align} \mathtt{ANS} &= \sum_{i=1}^{n}{(a_i+x)^3}\\ &= (a_1+b)^3+(a_2+b)^3+ \cdots + (a_n+b)^3 \\ &= (a_1^3 + 3a_1^2b + 3a_1b^2 + b^3) + (a_2^3 + 3a_2^2b + 3a_2b^2 + b^3) + \cdots + (a_n^3 + 3a_n^2b + 3a_nb^2 + b^3) \\ &= (a_1^3 + a_2^3 + \cdots + a_n^3) + 3b(a_1^2 + a_2^2 + \cdots + a_n^2) + 3b^2(a_1 + a_2 + \cdots + a_n) + nb^3 \\ &= \sum_{i=1}^{n}{a_i^3} + 3b\sum_{i=1}^{n}{a_i^2} + 3b^2\sum_{i=1}^{n}{a_i} + nb^3 \end{align} \]

综上所述,我们只需要用线段树维护三个字段,分别是区间的立方和、区间的平方和以及区间和。在维护平方和的过程中与维护立方和类似,根据完全平方公式 \((a+b)^2 = a^2 + 2ab + b^2\)。经过累加和化简可得:

\[\begin{align} \mathtt{ANS} &= \sum_{i=1}^{n}{(a_i+x)^2}\\ &= (a_1+b)^2 + (a_2+b)^2 + \cdots + (a_n+b)^2 \\ &= (a_1^2 + 2a_1b + b^2) + (a_2^2 + 2a_2b + b^2) + \cdots + (a_n^2 + 2a_nb + b^2) \\ &= (a_1^2 + a_2^2 + \cdots + a_n^2) + 2b(a_1 + a_2 + \cdots + a_n) + nb^2 \\ &= \sum_{i=1}^{n}{a_i^2} + 2b\sum_{i=1}^{n}{a_i} + nb^2 \end{align} \]

以上三个字段可以在构造线段树的时候一并初始化,之后的每次更新直接修改懒标记就可以了。一切都交给 push_down() 函数。在每次区间查询和修改之前都进行懒标记下放操作,对区间进行维护。具体维护操作如下:

cpp 复制代码
// rt 是父节点,l和r是rt的两个子节点,len是rt节点区间的长度。
// 其中,(len - len / 2)是l区间的长度,(len / 2)是r区间的长度。
void push_down(Node &rt, Node &l, Node &r, int len){
    if (rt.tag){
        int num = rt.tag;
		// 维护立方和
        l.s3 += 3 * num * l.s2 + 3 * num * num * l.s1 + (len - len / 2) * num * num * num;
        r.s3 += 3 * num * r.s2 + 3 * num * num * r.s1 + (len / 2) * num * num * num;
		
        //维护平方和
        l.s2 += 2 * num * l.s1 + (len - len / 2) * num * num;
        l.s2 += 2 * num * r.s1 + (len / 2) * num * num;

        // 维护区间总和
        l.s1 += (len - len / 2) * num;
        r.s1 += (len / 2) * num;
		
        // 将标记下放到两个子区间
        l.tag += num;
        r.tag += num;
        rt.tag = 0;  // 清空标记。
    }
    return ;
}

注意事项

  1. 请注意取模,为了保证答案正确性,请在每一步操作的时候都对结果取模。
  2. long long,不然的话只能过前三个测试点(出题人还是挺好的,留了三个小的测试点骗粉)。
  3. 在维护立方和、平方和以及和的时候,请注意维护的顺序。应当先维护立方和,再维护平方和,最后再维护区间总和。
  4. 注意线段树数组的大小,应当为 \(4 \times N\)。
  5. 建议使用读入优化,直接使用 cin 的效率比 std 慢约 \(100\%\)。

时间复杂度

线段树单次查询和修改的复杂度约为 \(O(log_2 N)\),初始化的时间复杂度为 \(\Theta(N)\),因此本代码的整体时间复杂度可以用多项式 \(\Theta(N + M \cdot log_2(N))\) 来表示,整体代码的时间复杂度就为 \(O(M \cdot log_2(N))\)。在极限数据下,程序只需要 \(160ms\) 就可以完成暴力一整年所需的工作。

代码实现

  1. 代码使用了宏定义,方便后期进行调式。
  2. 以下代码与普通的线段树无太大区别,但请着重关注 push_down() 下放操作。
cpp 复制代码
#include <iostream>
#include <algorithm>
using namespace std;

const int N = 1e5 + 5;
const int MOD = 1e9 + 7;
// 宏定义:lc和rc分别表示左儿子和右儿子在数组中的索引。
#define lc root << 1
#define rc root << 1 | 1
#define int long long
int n, m, k, x, y, v;

struct Node {
    // 分别表示总和、平方和与立方和。
    int s1, s2, s3;
    int tag;
} tree[N << 2];

// 合并区间,直接将两个子区间相加即可。
void push_up(int root) {
    tree[root].s1 = (tree[lc].s1 + tree[rc].s1) % MOD;
    tree[root].s2 = (tree[lc].s2 + tree[rc].s2) % MOD;
    tree[root].s3 = (tree[lc].s3 + tree[rc].s3) % MOD;
    return;
}

// 下放操作,确实码量比较大。在借鉴的时候需要仔细点。
void push_down(Node &rt, Node &l, Node &r, int len) {
    if (rt.tag) {
        int num = rt.tag;

        l.s3 = (l.s3 + 3 * num % MOD * l.s2 % MOD + 3 * num % MOD * num % MOD * l.s1 % MOD + (len - len / 2) * num % MOD * num % MOD * num % MOD) % MOD;
        l.s3 = (l.s3 + MOD) % MOD;
        r.s3 = (r.s3 + 3 * num % MOD * r.s2 % MOD + 3 * num % MOD * num % MOD * r.s1 % MOD + (len / 2) * num % MOD * num % MOD * num % MOD) % MOD;
        r.s3 = (r.s3 + MOD) % MOD;

        l.s2 = (l.s2 + 2 * num % MOD * l.s1 % MOD + (len - len / 2) * num % MOD * num % MOD) % MOD;
        l.s2 = (l.s2 + MOD) % MOD;
        r.s2 = (r.s2 + 2 * num % MOD * r.s1 % MOD + (len / 2) * num % MOD * num % MOD) % MOD;
        r.s2 = (r.s2 + MOD) % MOD;

        l.s1 = (l.s1 + (len - len / 2) * num % MOD) % MOD;
        l.s1 = (l.s1 + MOD) % MOD;
        r.s1 = (r.s1 + (len / 2) * num % MOD) % MOD;
        r.s1 = (r.s1 + MOD) % MOD;

        l.tag = (l.tag + num) % MOD;
        l.tag = (l.tag + MOD) % MOD;
        r.tag = (r.tag + num) % MOD;
        r.tag = (r.tag + MOD) % MOD;
        rt.tag = 0;
    }
    return;
}

// 非常简单的造树操作。
void build(int l, int r, int root) {
    if (l == r) {
        int t; cin >> t;
        tree[root].s1 = t % MOD; 
        tree[root].s2 = t * t % MOD;
        tree[root].s3 = t * t % MOD * t % MOD;
        return;
    }
    int mid = (l + r) >> 1;
    build(l, mid, lc);
    build(mid + 1, r, rc);
    push_up(root);
    return;
}

// 更新操作。
void update(int l, int r, int v, int L, int R, int root) {
    // 跟push_down()函数基本类似。
    if (L <= l && r <= R) {
        tree[root].tag = (tree[root].tag + v) % MOD;
        tree[root].tag = (tree[root].tag + MOD) % MOD;
        tree[root].s3 = (tree[root].s3 + 3 * v % MOD * tree[root].s2 % MOD + 3 * v % MOD * v % MOD * tree[root].s1 % MOD + (r - l + 1) * v % MOD * v % MOD * v % MOD) % MOD;
        tree[root].s3 = (tree[root].s3 + MOD) % MOD;
        tree[root].s2 = (tree[root].s2 + 2 * v % MOD * tree[root].s1 % MOD + (r - l + 1) * v % MOD * v % MOD) % MOD;
        tree[root].s2 = (tree[root].s2 + MOD) % MOD;
        tree[root].s1 = (tree[root].s1 + (r - l + 1) * v % MOD) % MOD;
        tree[root].s1 = (tree[root].s1 + MOD) % MOD;
        return;
    }
    push_down(tree[root], tree[lc], tree[rc], r - l + 1);
    int mid = (l + r) >> 1;
    if (L <= mid) update(l, mid, v, L, R, lc);
    if (R >= mid + 1) update(mid + 1, r, v, L, R, rc);
    push_up(root);
}

// 区间查询操作。
int query(int l, int r, int L, int R, int root) {
    if (L <= l && r <= R)
        return (tree[root].s3 + MOD) % MOD;
    int sum = 0;
    push_down(tree[root], tree[lc], tree[rc], r - l + 1);
    int mid = (l + r) >> 1;
    if (L <= mid) sum = (sum + query(l, mid, L, R, lc)) % MOD;
    if (R >= mid + 1) sum = (sum + query(mid + 1, r, L, R, rc)) % MOD;
    return (sum + MOD) % MOD;
}

signed main() {
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    cin >> n >> m;
    build(1, n, 1);
    while (m--) {
        cin >> k >> x >> y;
        if (k == 1) {
            cin >> v;
            update(1, n, v, x, y, 1);
        } else cout << query(1, n, x, y, 1) << endl;
    }
    return 0;
}

以下是本题的 Python 代码,但由于 Python 常数过大,没有办法通过所有的测试点:

python 复制代码
MOD = 10**9 + 7
N = int(1e5 + 5)

class Node:
    def __init__(self):
        self.s1 = 0
        self.s2 = 0
        self.s3 = 0
        self.tag = 0

tree = [Node() for _ in range(N * 4)]

def push_up(root):
    tree[root].s1 = (tree[root * 2].s1 + tree[root * 2 + 1].s1) % MOD
    tree[root].s2 = (tree[root * 2].s2 + tree[root * 2 + 1].s2) % MOD
    tree[root].s3 = (tree[root * 2].s3 + tree[root * 2 + 1].s3) % MOD

def push_down(root, l, r, length):
    if tree[root].tag != 0:
        num = tree[root].tag % MOD
        left_child = root * 2
        right_child = root * 2 + 1
        left_len = length - length // 2
        right_len = length // 2

        # Left child updates
        tree[left_child].s3 = (tree[left_child].s3 + (3 * num * tree[left_child].s2 % MOD) + (3 * num * num % MOD * tree[left_child].s1 % MOD) + (left_len * num % MOD * num % MOD * num % MOD)) % MOD
        tree[left_child].s2 = (tree[left_child].s2 + (2 * num * tree[left_child].s1 % MOD) + (left_len * num % MOD * num % MOD)) % MOD
        tree[left_child].s1 = (tree[left_child].s1 + (left_len * num % MOD)) % MOD
        tree[left_child].tag = (tree[left_child].tag + num) % MOD

        # Right child updates
        tree[right_child].s3 = (tree[right_child].s3 + (3 * num * tree[right_child].s2 % MOD) + (3 * num * num % MOD * tree[right_child].s1 % MOD) + (right_len * num % MOD * num % MOD * num % MOD)) % MOD
        tree[right_child].s2 = (tree[right_child].s2 + (2 * num * tree[right_child].s1 % MOD) + (right_len * num % MOD * num % MOD)) % MOD
        tree[right_child].s1 = (tree[right_child].s1 + (right_len * num % MOD)) % MOD
        tree[right_child].tag = (tree[right_child].tag + num) % MOD

        tree[root].tag = 0

def build(l, r, root):
    if l == r:
        t = data[l - 1] % MOD
        tree[root].s1 = t
        tree[root].s2 = t * t % MOD
        tree[root].s3 = t * t % MOD * t % MOD
        return
    mid = (l + r) // 2
    build(l, mid, root * 2)
    build(mid + 1, r, root * 2 + 1)
    push_up(root)

def update(l, r, v, L, R, root):
    if L <= l and r <= R:
        num = v % MOD
        length = r - l + 1
        tree[root].tag = (tree[root].tag + num) % MOD
        tree[root].s3 = (tree[root].s3 + (3 * num * tree[root].s2 % MOD) + (3 * num * num % MOD * tree[root].s1 % MOD) + (length * num % MOD * num % MOD * num % MOD)) % MOD
        tree[root].s2 = (tree[root].s2 + (2 * num * tree[root].s1 % MOD) + (length * num % MOD * num % MOD)) % MOD
        tree[root].s1 = (tree[root].s1 + (length * num % MOD)) % MOD
        return
    push_down(root, l, r, r - l + 1)
    mid = (l + r) // 2
    if L <= mid:
        update(l, mid, v, L, R, root * 2)
    if R > mid:
        update(mid + 1, r, v, L, R, root * 2 + 1)
    push_up(root)

def query(l, r, L, R, root):
    if L <= l and r <= R:
        return tree[root].s3 % MOD
    push_down(root, l, r, r - l + 1)
    mid = (l + r) // 2
    res = 0
    if L <= mid:
        res = (res + query(l, mid, L, R, root * 2)) % MOD
    if R > mid:
        res = (res + query(mid + 1, r, L, R, root * 2 + 1)) % MOD
    return res % MOD

if __name__ == '__main__':
    import sys
    sys.setrecursionlimit(1 << 25)
    n, m = map(int, sys.stdin.readline().split())
    data = list(map(int, sys.stdin.readline().split()))
    build(1, n, 1)
    for _ in range(m):
        tmp = sys.stdin.readline().split()
        if not tmp:
            continue
        k = int(tmp[0])
        x = int(tmp[1])
        y = int(tmp[2])
        if k == 1:
            v = int(tmp[3])
            update(1, n, v, x, y, 1)
        else:
            print(query(1, n, x, y, 1))