关于这个技巧我甚至都记不清是什么时候学的了,反正就是很早很早之前,当时学了之后看什么子树查询都想上 Dsu On Tree,后来也没怎么写过了,不过这个东西确确实实很强劲。
今天写了一上午教练的题单,大概获得了三天的时间来写自己想写的,就去写写各种莫队吧。
结果写到一个树上莫队,突然想使用这个东西,于是想着顺便写一些这个东西的文章。
由于忘了 CF 账号,所以好多写过的题都忘了是啥了,所以接下来是入门向的,目的防止我自己忘掉。
引入
对于一些树上的问题,每次查询子树,并且这些信息难以合并,需要开个桶或者使用 DS 进行维护,我们可以尝试这个东西。
如果 \(O(1)\) 一次查询或删改,这个东西的复杂度是 \(O(nlogn)\) 的。
这个东西在一定程度上可以替代点分治,注意我指的是一定程度上。
我们的整个算法是有一个特定的流程的,所以写起来是异常的简单。
就按照 OIWIKI 上的例子来引入吧。
给你一棵数,有一些询问,问你一个子树中的颜色数量是多少。
很好理解的题意。
我们发现判断颜色的数量,必须一定程度上去维护每一个颜色的信息,显然这个东西难以进行有效的合并,所以容易想到开桶维护每一个颜色的出现个数,暴力扫描每一个子树,明显这个算法是 \(O(n^2)\) 级别的。
这个桶就是数组维护每个颜色的出现次数,怕我的语言过分抽象。
先不谈我们怎么减少这类做法的时间复杂度,我们先思考有什么办法来减轻我们的工作量。
我们可以尝试这么一个方案:假设我们在处理一棵子树的时候,我们肯定是要 dfs 到下边的子树的,所以我们可以在桶里边保留一些东西的贡献,这个样子我们似乎就可以稍微优化一些了。
我们来讨论一下如何进行合法的保留,仔细思索后我们会发现能保留的东西是很苛刻的。
我们是无法保留两个子树的,所以贪心来想,我们去选择重儿子所在的子树进行保留就行。
这个重儿子就是我们树链剖分的那个重儿子,别告诉我有人连树剖都不会。
这样我们就成功进行了优化,现在这个东西大概就是 \(O(nlogn)\) 级别的了。
什么?这就完了?凭什么?我不信。
我们冷静分析一波,来思考一下每个节点会被扫到多少次?
容易发现,在某个节点到根中路径中,我们除去它本身会被遍历的这一次,这个路径中每出现一次轻边,这个点就会被扫一次,又因为这个数目不会超过 \(log n\) 级别的,所以 \(O(nlogn)\) 的时间复杂度也并不难理解了。
引入的解答
我们来考虑一下如何解决上边的引入。
给一下链接:https://www.luogu.com.cn/problem/U41492
按照上边的思路,我们写一下最经典的树剖,读入,这个我们不多说了,结下来我们讲一下 dsu on tree 的基本流程。
-
不保留对桶中的贡献,尝试去解决轻儿子所在子树
-
保留对桶中的贡献,解决重儿子所在子树
-
再一次遍历轻儿子所在子树,把这一次的东西加入桶中,计算答案。
int siz[MN], son[MN], dfn_cnt, l[MN], r[MN], col[MN], tmp[MN];
void dfs1(int u, int father){siz[u]=1; l[u]=++dfn_cnt; col[dfn_cnt]=tmp[u];for(int i=head[u];i;i=node[i].nxt){int v=node[i].to;if(v==father) continue;dfs1(v,u); siz[u]+=siz[v];if(siz[son[u]]<siz[v]) son[u]=v;}r[u]=dfn_cnt;
}
int cnt[MN], res, ans[MN];
vector <int> querys[MN];
void Add(int col){cnt[col]++;if(cnt[col]==1) res++;
}
void Del(int col){cnt[col]--;if(cnt[col]==0) res--;
}
int Getans(){return res;
}
void solve(int u, int father, bool keep){for(int i=head[u];i;i=node[i].nxt){int v=node[i].to;if(v==father||v==son[u]) continue;solve(v,u,false);}if(son[u]) solve(son[u],u,true);for(int i=head[u];i;i=node[i].nxt){int v=node[i].to;if(v==father||v==son[u]) continue;for(int j=l[v]; j<=r[v]; ++j) Add(col[j]);}Add(tmp[u]);for(auto qry:querys[u]) ans[qry]=Getans();if(!keep){for(int j=l[u]; j<=r[u]; ++j) Del(col[j]);}
}
int n, m;
int main(){ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);cin>>n;for(int i=1,u,v; i<n; ++i){cin>>u>>v; insert(u,v); insert(v,u);}for(int i=1; i<=n; ++i) cin>>tmp[i];cin>>m;for(int i=1; i<=m; ++i){int v; cin>>v; querys[v].push_back(i);}dfs1(1,1); solve(1,1,0);for(int i=1; i<=m; ++i) cout<<ans[i]<<'\n';return 0;
}
解释一下我的代码可能令人不解的点。
我们会发现我们需要不同的遍历,为了更方便扫描子树中的东西,我们使用 dfs 序来更简洁处理问题。
这下明白 l, r 是什么意思了吧。
这个实现千奇百怪,但是并没有什么难度,所以就不多说了。
某些例题。
P9233 [蓝桥杯 2023 省 A] 颜色平衡树
题意
给你一颗节点数为 \(n\) 的树,每个点有一个颜色,询问有多少节点满足它子树中各个颜色节点数相等。
我们还需要 \(O(n\log n)\) 左右的算法。
分析
询问子树上的统计问题,我们可以使用树上启发式合并,这个题还是很套路的。
如何快速判断各个颜色数相等?我们记录一下总共出现的颜色数和每个节点数的颜色数量就行,这两个在树上启发式合并过程中维护就行,很简单,最后询问的时候我们利用所询问节点的颜色判断即可。
算是板子,具体实现看一下代码便懂了。
代码
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int MN=5e5+515;
struct Node{int nxt, to;
}node[MN];
int head[MN], tottt;
inline void insert(int u, int v){node[++tottt].to=v;node[tottt].nxt=head[u];head[u]=tottt;return;
}
int cntid[MN];//统计节点数量为i的颜色个数
int col[MN], tmp[MN];//存颜色的
int cntcol[MN], totcol;//统计一共有多少种颜色
void Add(int c){if(cntcol[c]==0) totcol++;cntcol[c]++;cntid[cntcol[c]-1]--;cntid[cntcol[c]]++;
}
void Del(int c){if(cntcol[c]==1) totcol--;cntcol[c]--;cntid[cntcol[c]+1]--;cntid[cntcol[c]]++;
}
int Getsub(int u){int c=tmp[u];if(totcol==cntid[cntcol[c]]) return true;return false;
}
int depth[MN], fa[MN], siz[MN], son[MN], l[MN], r[MN];
int dfn_cnt;
void dfs1(int u, int father){fa[u]=father; depth[father]=depth[u]-1; siz[u]=1;l[u]=++dfn_cnt; col[dfn_cnt]=tmp[u];for(int i=head[u];i;i=node[i].nxt){int v=node[i].to;if(v==father) continue;dfs1(v,u);siz[u]+=siz[v];if(siz[son[u]]<siz[v]) son[u]=v;}r[u]=dfn_cnt;
}
int ans=0;
void dfs(int u, int father, bool keep){for(int i=head[u];i;i=node[i].nxt){int v=node[i].to;if(v==father||v==son[u]) continue;dfs(v,u,false);}if(son[u]) dfs(son[u],u,true);for(int i=head[u];i;i=node[i].nxt){int v=node[i].to;if(v==son[u]||v==father) continue;for(int j=l[v]; j<=r[v]; ++j){Add(col[j]);}}Add(tmp[u]);ans+=Getsub(u);if(!keep){for(int i=l[u]; i<=r[u]; ++i) Del(col[i]);}
}
int n;
signed main(){ios::sync_with_stdio(0), cin.tie(0), cout.tie(0); cin>>n;for(int i=1,c,father; i<=n; ++i){cin>>c>>father; tmp[i]=c;if(father) insert(i,father),insert(father,i);}dfs1(1,1);dfs(1,1,false);cout<<ans<<'\n';return 0;
}
CF375D Tree and Queries
这个也是同样的简单,一个道理,我们使用 BIT 快速搞就行。
#include <bits/stdc++.h>
using namespace std;
const int MN=1e6+116;
struct Node{int nxt, to;
}node[MN];
int head[MN], tottt;
inline void insert(int u, int v){node[++tottt].to=v;node[tottt].nxt=head[u];head[u]=tottt; return;
}
int siz[MN], son[MN], col[MN];
int dfn_cnt, l[MN], r[MN], tmp[MN];
void dfs1(int u, int father){siz[u]=1; l[u]=++dfn_cnt; col[dfn_cnt]=tmp[u];for(int i=head[u]; i; i=node[i].nxt){int v=node[i].to;if(v==father) continue;dfs1(v,u); siz[u]+=siz[v];if(siz[son[u]]<siz[v]) son[u]=v;}r[u]=dfn_cnt;
}
int n, m, k, res=0;
int tr[MN], cnt[MN];
int lowbit(int x){return x&(-x);
}
void update(int pos, int val){for(int i=pos; i<MN; i+=lowbit(i)) tr[i]+=val;
}
int qval(int pos){int res=0;for(int i=pos; i; i-=lowbit(i)) res+=tr[i];return res;
}
void Add(int col){if(cnt[col]) update(cnt[col],-1);cnt[col]++;if(cnt[col]==1) res++;update(cnt[col],1);
}
void Del(int col){update(cnt[col],-1);cnt[col]--;if(cnt[col]==0) res--;if(cnt[col]) update(cnt[col],1);
}
int Getans(int k){return qval(k-1);
}
struct Qrys{int id, k;
};
vector <Qrys> querys[MN];
int ans[MN];
void solve(int u, int father, int keep){for(int i=head[u];i;i=node[i].nxt){int v=node[i].to;if(v==father||v==son[u]) continue;solve(v,u,false);}if(son[u]) solve(son[u],u,true);for(int i=head[u];i;i=node[i].nxt){int v=node[i].to;if(v==father||v==son[u]) continue;for(int j=l[v]; j<=r[v]; ++j) Add(col[j]);}Add(tmp[u]);for(auto qry:querys[u]){ans[qry.id]=res-Getans(qry.k);//cout<<Getans(qry.k)<<'\n';}if(!keep){for(int j=l[u]; j<=r[u]; ++j) Del(col[j]);}
}
int main(){memset(cnt,0,sizeof(cnt));ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);cin>>n>>m; for(int i=1; i<=n; ++i) cin>>tmp[i];for(int i=1,u,v; i<n; ++i){cin>>u>>v; insert(u,v); insert(v,u);}dfs1(1,1);for(int i=1; i<=m; ++i){int u, k; cin>>u>>k;querys[u].push_back({i,k});}solve(1,1,0);for(int i=1; i<=m; ++i) cout<<ans[i]<<'\n';return 0;
}
/*
10 2
75 72 81 90 62 39 32 88 61 58
2 1
3 2
4 3
5 1
6 4
7 1
8 4
9 6
10 8
2 99
2 68
*/
不太想写了,这些题目的具体就是一些实现问题罢了,所以我就懒得写太多了