幼儿园里有 N 个小朋友,老师现在想要给这些小朋友们分配糖果,要求每个小朋友都要分到糖果。
但是小朋友们也有嫉妒心,总是会提出一些要求,比如小明不希望小红分到的糖果比他的多,于是在分配糖果的时候, 老师需要满足小朋友们的 K 个要求。
幼儿园的糖果总是有限的,老师想知道他至少需要准备多少个糖果,才能使得每个小朋友都能够分到糖果,并且满足小朋友们所有的要求。
输入格式
输入的第一行是两个整数 N,K。
接下来 K 行,表示分配糖果时需要满足的关系,每行 3 个数字 X,A,B。
- 如果 X=1.表示第 A 个小朋友分到的糖果必须和第 B 个小朋友分到的糖果一样多。
- 如果 X=2,表示第 A 个小朋友分到的糖果必须少于第 B 个小朋友分到的糖果。
- 如果 X=3,表示第 A 个小朋友分到的糖果必须不少于第 B 个小朋友分到的糖果。
- 如果 X=4,表示第 A 个小朋友分到的糖果必须多于第 B 个小朋友分到的糖果。
- 如果 X=5,表示第 A 个小朋友分到的糖果必须不多于第 B 个小朋友分到的糖果。
小朋友编号从 1 到 N。
输出格式
输出一行,表示老师至少需要准备的糖果数,如果不能满足小朋友们的所有要求,就输出 −1。
数据范围
1≤N≤105
1≤K≤105
1≤X≤5
1≤A,B≤N
输入数据完全随机。
输入样例:
5 7
1 1 2
2 3 2
4 4 1
3 4 5
5 4 5
2 3 5
4 5 1
输出样例:
11
解析:
差分约束
(1)求不等式组的可行解
源点需要满足的条件:从源点出发,一定可以走到所有的边。
步骤:
【1】先将每个不等式xi<=xj+ck,转化成一条从xj走到xi,长度为ck的一条边。
【2】找一个超级源点,使得该源点一定可以遍历所有的边
【3】从源点求一遍单源最短路
结果1:如果存在负环,则原不等式组一定无解
结果2:如果没有负环,则dist[i]就是原不等式组的一个可行解
(2)如何求最大值或最小值,这里的最值指的是每个变量的最值
结论:如果求的是最小值,则应该求最长路;如果球的是最大值,则应该求最短路;
问题:如何转化xi<=c,其中c是一个常数,这类的不等式。
方法:建立一个超级源点,0,然后建立0->i,长度是c的边即可。
要求xi的最大值为例:求所有从xi出发,构成的不等式链xi<=xj+c1<=xk+c2+c1<=......<=c1+c2+......
所计算出的上界,最终xi的最大值等于所有上界的最小值。
cpp
#include<iostream>
#include<string>
#include<cstring>
#include<cmath>
#include<ctime>
#include<algorithm>
#include<utility>
#include<stack>
#include<queue>
#include<vector>
#include<set>
#include<math.h>
#include<map>
#include<sstream>
#include<deque>
#include<unordered_map>
using namespace std;
typedef long long LL;
const int N = 1e5+5, M = 3e5 + 5;
int n, k;
int h[N], e[M], w[M], ne[M], idx;
LL dist[N];
int q[N], vis[N], cnt[N];
void add(int a, int b, int c) {
e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx++;
}
int spfa() {
memset(dist, -0x3f, sizeof dist);
int hh = 0, tt = 1;
q[0] = 0;
dist[0] = 0;
vis[0] = 1;
while (hh != tt) {
int t = q[--tt];
vis[t] = 0;
for (int i = h[t]; i != -1; i = ne[i]) {
int j = e[i];
if (dist[j] < dist[t] + w[i]) {
dist[j] = dist[t] + w[i];
cnt[j] = cnt[t] + 1;
if (cnt[j] >= n + 1)return 0;
if (!vis[j]) {
q[tt++] = j;
vis[j] = 1;
}
}
}
}
return 1;
}
int main() {
cin >> n >> k;
memset(h, -1, sizeof h);
for (int i = 1,a,b,x; i <= k; i++) {
scanf("%d%d%d", &x, &a, &b);
if (x == 1)add(b, a, 0), add(a, b, 0);
else if (x == 2)add(a, b, 1);
else if (x == 3)add(b, a, 0);
else if (x == 4)add(b, a, 1);
else add(a, b, 0);
}
for (int i = 1; i <= n; i++) {
add(0, i, 1);
}
if (!spfa()) {
cout << -1 << endl;
}
else {
LL ret = 0;
for (int i = 1; i <= n; i++)ret += dist[i];
cout << ret << endl;
}
return 0;
}