UVa12327/LA5705 Xavier is Learning to Count
题目链接
本题是2011年icpc亚洲区域赛上海赛区的C题
题意
已知正整数 p、m 和 m 个不同的正整数 Ai,从这些正整数中选出 p 个,使得它们的和为 n,方案总数设为 f(n)。求出所有 f(n) 不等于0 的 n 和对应的 f(n)。1≤p≤5,1≤m,Ai≤13 000。
分析
生成函数的乘法+FFT,不过需要去重。拿 p=5 的情况举例子,直接用生成函数连乘4次存在以下重复计数:5个数相同;4个数相同;3个数相同且另两个数也相同;3个数相同且另两个数不相同;两个数相同还有另外也相同;仅2个数相同。所以需要把这些重复的情况也计算出来减掉,最后由于连乘形成了排列组合,需要将结果除以 p ! p! p!。
一个坑点
对于规模偏大的数据(比如 m ≥ 8000 , p = 5 m\ge8000,\;p=5 m≥8000,p=5),由于浮点乘法本身的精度损失,再反复做FTT乘法进一步带来损失,最后计算过程完全正确但是输出的结果可能偏小,需要在向上取整时适当弥补一下才能AC。
AC 代码
cpp
#include <iostream>
#include <cstring>
#include <cmath>
using namespace std;
#define M 13001
#define N 1<<16
bool f[M]; int tot, x, kase = 0;
struct complex {
double x, y;
void operator+= (const complex &t) {
x += t.x; y += t.y;
}
complex operator- (const complex &t) const {
return {x - t.x, y - t.y};
}
complex operator* (const complex &t) const {
return {x * t.x - y * t.y, x * t.y + y * t.x};
}
} s1[N], s2[N], s3[N], s4[N], s5[N], s6[N], s7[N];
void fft(complex (&a)[N], int inv) {
for (int i=0, j=0; i<tot; ++i) {
if(j > i) {complex t = a[i]; a[i] = a[j]; a[j] = t;}
int k = tot;
while(j & (k >>= 1)) j &= ~k;
j |= k;
}
for(int step=1; step<tot; step<<=1) {
double alpha = inv*M_PI / step;
for(int k=0; k<step; k++) {
complex wk = {cos(alpha*k), sin(alpha*k)};
for(int Ek=k; Ek<tot; Ek += step<<1) {
int Ok = Ek + step; complex t = wk * a[Ok];
a[Ok] = a[Ek] - t; a[Ek] += t;
}
}
}
}
void solve1() {
for (int i=1; i<=x; ++i) if (f[i]) cout << i << ": " << 1 << endl;
}
void solve2() {
fft(s2, -1);
for (int i=0; i<tot; ++i) {
long long a = s2[i].x / tot + .5;
if (~i&1 && (i>>1)<=x && f[i>>1]) --a;
if (a) cout << i << ": " << a/2 << endl;
}
}
void solve3() {
for (int i=0; i<tot; ++i) s3[i] = s2[i] * s1[i];
for (int i=0; i<tot; ++i) s2[i] = {~i&1 && (i>>1)<=x && f[i>>1] ? 1. : 0., 0.};
fft(s2, 1);
for (int i=0; i<tot; ++i) s2[i] = s2[i] * s1[i];
fft(s2, -1); fft(s3, -1);
for (int i=0; i<tot; ++i) {
long long a = s3[i].x / tot + .5, b = s2[i].x / tot + .5;
if (i%3==0 && i/3<=x && f[i/3]) --a, --b;
a -= 3*b;
if (a) cout << i << ": " << a/6 << endl;
}
}
void solve4() {
for (int i=0; i<tot; ++i) s4[i] = s2[i] * s2[i];
for (int i=0; i<tot; ++i) s3[i] = {i%3==0 && i/3<=x && f[i/3] ? 1. : 0., 0.};
for (int i=0; i<tot; ++i) s5[i] = {~i&1 && (i>>1)<=x && f[i>>1] ? 1. : 0., 0.};
fft(s3, 1); fft(s5, 1);
for (int i=0; i<tot; ++i) s3[i] = s3[i] * s1[i];
for (int i=0; i<tot; ++i) s2[i] = s5[i] * s2[i];
for (int i=0; i<tot; ++i) s5[i] = s5[i] * s5[i];
fft(s2, -1); fft(s3, -1); fft(s4, -1); fft(s5, -1);
for (int i=0; i<tot; ++i) {
long long a = s4[i].x / tot + .5, b = s3[i].x / tot + .5, c = s2[i].x / tot + .5, d = s5[i].x / tot + .5;
if ((i&3)==0 && (i>>2)<=x && f[i>>2]) --a, --b, --c, --d;
a -= 6*c - 8*b - 3*d;
if (a) cout << i << ": " << a/24 << endl;
}
}
void solve5() {
for (int i=0; i<tot; ++i) s4[i] = {(i&3)==0 && (i>>2)<=x && f[i>>2] ? 1. : 0., 0.};
for (int i=0; i<tot; ++i) s3[i] = {i%3==0 && i/3<=x && f[i/3] ? 1. : 0., 0.};
fft(s4, 1); fft(s3, 1);
for (int i=0; i<tot; ++i) s6[i] = s3[i];
for (int i=0; i<tot; ++i) s5[i] = s2[i] * s2[i] * s1[i];
for (int i=0; i<tot; ++i) s4[i] = s4[i] * s1[i];
for (int i=0; i<tot; ++i) s3[i] = s3[i] * s2[i];
for (int i=0; i<tot; ++i) s2[i] = {~i&1 && (i>>1)<=x && f[i>>1] ? 1. : 0., 0.};
fft(s2, 1);
for (int i=0; i<tot; ++i) s7[i] = s2[i] * s2[i] * s1[i];
for (int i=0; i<tot; ++i) s6[i] = s6[i] * s2[i];
for (int i=0; i<tot; ++i) s2[i] = s2[i] * s1[i] * s1[i] * s1[i];
fft(s2, -1); fft(s3, -1); fft(s4, -1); fft(s5, -1); fft(s6, -1); fft(s7, -1);
for (int i=0; i<tot; ++i) {
long long a = s5[i].x / tot + .5, b = s4[i].x / tot + .5, c = s3[i].x / tot + .5,
d = s2[i].x / tot + .5, e = s6[i].x / tot + .5, g = s7[i].x / tot + .5;
if (i%5==0 && i/5<=x && f[i/5]) --a, --b, --c, --d, --e, --g;
g -= b + 2*e; c -= 2*b + e; d -= 3*(b + c + g) + 4*e;
a = (a - 5*b - 10*(c + d + e) - 15*g + 60) / 120;
if (a > 0) cout << i << ": " << a << endl;
}
}
void solve() {
int m, p, v; cin >> m >> p;
memset(f, x = 0, sizeof(f));
while (m--) cin >> v, x = max(x, v), f[v] = true;
for (m=x*p, tot=1; tot <= m; tot<<=1);
for (int i=0; i<tot; ++i) s1[i] = {i<=x && f[i] ? 1. : 0., 0.};
fft(s1, 1);
for (int i=0; i<tot; ++i) s2[i] = s1[i] * s1[i];
cout << "Case #" << ++kase << ':' << endl;
if (p == 1) solve1();
else if (p == 2) solve2();
else if (p == 3) solve3();
else if (p == 4) solve4();
else solve5();
cout << endl;
}
int main() {
ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
int t; cin >> t;
while (t--) solve();
return 0;
}