题意:长度为 n 的数组 a 初始全为 0。每次操作任选 k 个互不相同的下标 c[1..k](1 <= c[i] <= n),对每个 1 <= i <= k 把 a[c[i]] 改为 c[i % k + 1];操作次数不限。给定 n、k 和数组 b,判断能否通过若干次操作使 a 变成 b。
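思路:把 i -> b[i] 看成一张函数图(每个点恰有一条出边)。k = 1 时每次操作只能令 a[x] = x,所以必须对所有 i 都有 b[i] = i;k > 1 时,图中每个环(包括自环)的长度都必须恰好为 k,而不在环上的点总能通过适当安排操作顺序被正确赋值,所以该条件也是充分的。法一用 Tarjan 求强连通分量:在函数图中,大小大于 1 的强连通分量恰好就是一个环;自环所在分量大小为 1,需要单独判断。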
cpp
#include <bits/stdc++.h>
using namespace std;
// Competitive-programming template preamble.
// NOTE: `int` is redefined as `long long` for the whole file, which is why the
// entry point below has to be declared `signed main`.
#define int long long
#define pb push_back
#define fi first
#define se second
// Segment-tree child indices; not referenced by this solution.
#define lson p << 1
#define rson p << 1 | 1
// maxn bounds every global array below; inf and maxm are not referenced here.
const int maxn = 1e6 + 5, inf = 1e18, maxm = 4e4 + 5;
// Template leftovers (mod, N) -- not referenced by this solution.
const int mod = 1e9 + 7;
// const int mod = 998244353;
const int N = 1e6;
// int a[505][5005];
// bool vis[505][505];
// char s[505][505];
// a[1..n]: the target array of the current test case (read in solve()).
// b, vis, s are template leftovers and are not referenced.
int a[maxn], b[maxn];
bool vis[maxn];
string s;
// n: array length of the current test case (set in solve()); m is not referenced.
int n, m;
// Value/index pair ordered by `val` alone (`id` does not take part in the
// comparison). Template leftover: the array `c` is not referenced by this solution.
struct Node {
    int val;
    int id;
    bool operator<(const Node &rhs) const { return rhs.val > val; }
} c[maxn];
int ans[maxn];                          // not referenced by this solution
// G[i]: out-edges of vertex i (solve() adds exactly one, i -> a[i]).
// g is cleared in solve() but never filled.
vector<int> G[maxn], g[maxn];
// Tarjan bookkeeping: DFS discovery time, low-link value, "currently on st" flag.
int dfn[maxn], low[maxn], in_st[maxn];
// col[v]: SCC id of v; cnt_col: number of SCCs found so far; tot: DFS time counter.
int col[maxn], cnt_col, tot;
// Per-SCC accumulators indexed by SCC id: sum of a[] over its vertices, vertex count.
int sum[maxn], siz[maxn];
stack<int> st;                          // Tarjan's vertex stack
// Tarjan's strongly-connected-components DFS rooted at vertex u (recursive).
// Fills dfn[]/low[] for every vertex it reaches. Whenever a vertex turns out to
// be the first-visited vertex of its SCC, the whole SCC is popped off `st`, its
// vertices get a fresh id in col[], and siz[id] / sum[id] receive the SCC's
// vertex count and the sum of a[] over its vertices.
void tarjan(int u){
    low[u] = dfn[u] = ++tot;
    st.push(u);
    in_st[u] = 1;
    for(auto v : G[u]){
        if(!dfn[v]){
            // Unvisited: finish v first, then inherit its low-link.
            tarjan(v);
            low[u] = min(low[u], low[v]);
        }else if(in_st[v]){
            // Back or cross edge to a vertex still on the stack.
            low[u] = min(low[u], dfn[v]);
        }
    }
    if(low[u] != dfn[u]){
        return;  // u is not the root of its SCC; an ancestor will pop it
    }
    // u is the root: u and everything pushed after it form one SCC.
    ++cnt_col;
    int x;
    do{
        x = st.top();
        st.pop();
        in_st[x] = 0;
        col[x] = cnt_col;
        sum[cnt_col] += a[x];
        siz[cnt_col]++;
    }while(x != u);
}
void solve(){
int res = 0;
int q, k;
cin >> n >> k;
tot = 0;
cnt_col = 0;
for(int i = 1; i <= n; i++){
G[i].clear();
g[i].clear();
siz[i] = 0;
in_st[i] = 0;
dfn[i] = low[i] = 0;
}
while(!st.empty()){
st.pop();
}
for(int i = 1; i <= n; i++){
cin >> a[i];
G[i].pb(a[i]);
}
if(k == 1){
for(int i = 1; i <= n; i++){
if(a[i] != i){
cout << "No\n";
return;
}
}
cout << "Yes\n";
return;
}
for(int i = 1; i <= n; i++){
if(a[i] == i){
cout << "No\n";//不能自环
return;
}
if(!dfn[i]){
tarjan(i);
}
}
for(int i = 1; i <= cnt_col; i++){
if(siz[i] > 1 && siz[i] != k){
cout << "No\n";
return;
}
}
cout << "Yes\n";
}
// Entry point for method 1. Unties cin from cout and disables C-stdio sync,
// then reads the number of test cases and runs solve() once per case.
// Declared `signed` because `int` is macro-redefined to `long long` in this file.
signed main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int T = 1;
    cin >> T;
    for(; T != 0; --T){
        solve();
    }
    return 0;
}
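法二:不用 Tarjan,从每个点出发沿 i -> b[i] 走,把途经的未访问点标记为本轮起点;若停在本轮刚标记过的点上,说明发现了一个新环,数出环长并与 k 比较(k = 1 时仍要求 b[i] = i)。总时间复杂度 O(n)。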
cpp
#include <bits/stdc++.h>
using i64 = long long;  // 64-bit integer alias; not referenced by this solution
// Solves one test case with method 2: walk the functional graph i -> b[i] directly.
// Reads n, k and b[1..n] (stored 0-based). Prints "YES" if every cycle of the
// graph has length exactly k (for k == 1: if b is the identity), otherwise "NO".
void solve() {
    int n, k;
    std::cin >> n >> k;
    std::vector<int> nxt(n);
    for (int &x : nxt) {
        std::cin >> x;
        --x;  // switch to 0-based vertex indices
    }

    if (k == 1) {
        // With k == 1 every operation writes a[x] = x, so only the identity is reachable.
        bool identity = true;
        for (int i = 0; i < n && identity; ++i) {
            identity = (nxt[i] == i);
        }
        std::cout << (identity ? "YES\n" : "NO\n");
        return;
    }

    // owner[v] = start vertex of the walk that first reached v, or -1 if unvisited.
    std::vector<int> owner(n, -1);
    bool ok = true;
    for (int start = 0; start < n && ok; ++start) {
        int cur = start;
        while (owner[cur] < 0) {
            owner[cur] = start;
            cur = nxt[cur];
        }
        // Stopping on a vertex claimed by an earlier walk: no new cycle here.
        if (owner[cur] != start) {
            continue;
        }
        // cur lies on a cycle first discovered by this walk: measure its length.
        int cycleLen = 1;
        for (int y = nxt[cur]; y != cur; y = nxt[y]) {
            ++cycleLen;
        }
        ok = (cycleLen == k);
    }
    std::cout << (ok ? "YES\n" : "NO\n");
}
// Entry point for method 2. Disables C-stdio sync and unties cin from cout,
// then reads the number of test cases and runs solve() once per case.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    int remaining;
    std::cin >> remaining;
    for (; remaining != 0; --remaining) {
        solve();
    }
    return 0;
}