碎碎念
学完Johnson已经好久了但一直没有时间总结,今天终于有时间了来写一下
其实这个算法还是比较简单的,刚学完最短路的小蒟蒻也可以学会
求点赞 + 评论qwq,支持一下小蒟蒻吧OvO
例题:
大意:我们需要在 \(O(nm)\) 的时间复杂度下处理带负权图的全源最短路
我们熟知的最短路算法有SPFA、dijstra和floyd,但是他们都无法达到优秀的复杂度
-
SPFA虽然理论上是跑不满的,但是稍加卡卡就会被卡到 \(O(n ^ 2 m)\)。
-
dijstra处理不了带负权的图
-
floyd复杂度为 \(O(n ^ 3)\) 也通过不了本题
这时候有人就要问了,我们直接给每条边的边权加上一个固定值,使每个边权都不为负数不就可以了
但这是错误的,我们需要加上特定的一个数才可以保证答案不会变
我有一个绝妙的证明,但这里空白太小,我写不下
咳咳其实我不会证
正片开始
我们还是一样的思路,将负边权全部转化为正数。
那么我们怎么求出"特定的数"呢?
这就是 Johnson 的核心思想了
算法过程
-
新建一个点,并对每一个点都连上边
-
利用 SPFA 算出每个点到刚刚新建的点的最短距离,因为还存在负边权,所以不能使用dijstra
-
给每一条边的边权更新为这条边连接的两个节点的离新建节点最短路 的差值 并加上原来的边权
-
这个时候已经没有负边权了,直接使用dijstra计算最短路
-
因为我们之前加上了一个
两个节点的离新建节点最短路的差值
所以输出时要减去这个结果
注意:dijstra最好使用堆优化,这样时间复杂度就为 \(O(nlogn)\)
总时间复杂度为 \(O(nm log m)\)
code:
cpp
#include<bits/stdc++.h>
using namespace std;
#define int long long
int n, m, vis[5005], h[5005], dis[5005], f[5005];
struct node {
int cnt, head[5005], nxt[10005], to[10005], data[10005];
void add(int x, int y, int w) {
nxt[++ cnt] = head[x];
head[x] = cnt;
to[cnt] = y;
data[cnt] = w;
}
}qwq;
bool SPFA(int x) {
queue<int> q;
for(int i = 1;i <= n;i ++) {
h[i] = 1e9, f[i] = 0;
}
h[x] = 0, f[x] = true;
q.push(x);
while(!q.empty()) {
int xx = q.front();
q.pop();
f[xx] = 0;
for(int i = qwq.head[xx];i != 0;i = qwq.nxt[i]) {
if(h[qwq.to[i]] > h[xx] + qwq.data[i]) {
h[qwq.to[i]] = h[xx] + qwq.data[i];
if(!f[qwq.to[i]]) {
if(++ vis[qwq.to[i]] >= n + 1) {
return 0;
}
f[qwq.to[i]] = 1;
q.push(qwq.to[i]);
}
}
}
}
return 1;
}
void dijstra(int x) {
priority_queue<pair<int, int>, vector<pair<int, int> >, greater<pair<int, int> > > q;
for(int i = 1;i <= n;i ++) {
dis[i] = 1e9;
f[i] = 0;
}
q.push({0, x});
dis[x] = 0;
while(!q.empty()) {
int xx = q.top().second;
q.pop();
if(f[xx] == 0) {
f[xx] = 1;
for(int i = qwq.head[xx];i != 0;i = qwq.nxt[i]) {
if(dis[qwq.to[i]] > dis[xx] + qwq.data[i]) {
dis[qwq.to[i]] = dis[xx] + qwq.data[i];
if(f[qwq.to[i]] == 0) {
q.push({dis[qwq.to[i]], qwq.to[i]});
}
}
}
}
}
}
signed main() {
cin >> n >> m;
for(int i = 1;i <= m;i ++) {
int u, v, w;
cin >> u >> v >> w;
qwq.add(u, v, w);
}
for(int i = 1;i <= n;i ++) {
qwq.add(0, i, 0);
}
if(!SPFA(0)) {
cout << -1;
return 0;
}
for(int u = 1;u <= n;u ++) {
for(int i = qwq.head[u];i != 0;i = qwq.nxt[i]) {
qwq.data[i] += h[u] - h[qwq.to[i]];
}
}
for(int i = 1;i <= n;i ++) {
dijstra(i);
int ans = 0;
for(int j = 1;j <= n;j ++) {
if(dis[j] == 1e9) ans += j * 1e9;
else ans += j * (dis[j] + (h[j] - h[i]));
}
cout << ans << endl;
}
}