(2020.8.22 UPD: 感觉这篇可能需要大修一遍,一些分析现在看来有点不够准确...时间待定)
最近好颓废,什么都学不进去...
感谢两篇:AKMer - 浅谈树分治 言简意赅
LadyLex - 点分治&动态点分治小结 讲解+例题,学到很多东西
经常遇见一类树上的计数题,问的是在某些条件下,选择一些点的方案数
若对于每个点的统计都需要遍历以其为根节点的子树,普通的做法就是$O(n^2)$的,在很多时候是不满足要求的
而这是点分治的专长
点分治是这样进行的:
1. 找到当前树的重心
2. 将重心及重心连出的边全部删去,那么就能将原来的树分割成森林
3. 对于森林中的每棵树,继续找重心;不断地这样递归下去
其中,树的重心$x$表示,以$x$作为树的根,使得(以$x$的儿子为根的)最大子树大小最小
分析一下复杂度
一般来说,用到点分治的时候,需要对于当前子树$O(n)$进行dfs
得出重心也需要一个$O(n)$的dfs
而由于我们选择删去树的重心,所以分裂出的树中最大的不会超过原树大小的一半
所以整个算法的复杂度是$O(n\cdot logn)$
这样看来,点分治相当于从每一个点开始、对子树做一次dfs,但时间复杂度为$O(n\cdot logn)$;于是可以在很多问题中将一个$n$降成一个$logn$
模板题:Luogu P3806 (【模板】点分治1)
这道题可以用点分治这样解决:
若存在一条路径的长度为$k$,则对于当前子树的根$x$,要不在路径上,要不不在路径上
1. 若$x$在路径上,则相当于两个点$u,v$到$x$的距离之和为$k$,且$LCA(u,v)=x$
2. 若$x$不在路径上,则对于每个 $x$的儿子为根的子树 递归下去
由于$k<1\times 10^7$,所以可以开一个$cnt$数组,记录到$x$距离为$dist$的节点数$cnt[dist]$,记得适时清空
这题里面,对$cnt$数组的统计和赋值最好分开做,以免产生影响
#include <cstdio> #include <vector> #include <cstring> using namespace std; typedef pair<int,int> pii; const int N=10005; const int K=10000005; int n,m,val; vector<pii> v[N]; bool flag; bool vis[N]; int root; int sz[N],mx[N]; inline void Find(int x,int fa,int tot) { sz[x]=1; mx[x]=0; for(int i=0;i<v[x].size();i++) { int nxt=v[x][i].first; if(vis[nxt] || nxt==fa) continue; Find(nxt,x,tot); sz[x]+=sz[nxt]; mx[x]=max(mx[x],sz[nxt]); } mx[x]=max(mx[x],tot-sz[x]); if(!root || mx[x]<mx[root]) root=x; } int cnt[K]; inline void dfs(int x,int fa,int sum,int type) { if(sum<K) { cnt[sum]+=type; if(type==0 && val-sum>=0 && cnt[val-sum]) flag=true; } for(int i=0;i<v[x].size();i++) { int nxt=v[x][i].first,len=v[x][i].second; if(vis[nxt] || nxt==fa) continue; dfs(nxt,x,sum+len,type); } } inline void Calc(int x,int tot) { root=0; Find(x,0,tot); int cur=root; Find(cur,0,tot); cnt[0]++; for(int i=0;i<v[cur].size();i++) { int nxt=v[cur][i].first,len=v[cur][i].second; if(vis[nxt]) continue; dfs(nxt,cur,len,0); dfs(nxt,cur,len,1); } dfs(cur,0,0,-1); vis[cur]=true; for(int i=0;i<v[cur].size();i++) { int nxt=v[cur][i].first; if(vis[nxt]) continue; Calc(nxt,sz[nxt]); } vis[cur]=false; } int main() { scanf("%d%d",&n,&m); for(int i=1;i<n;i++) { int x,y,w; scanf("%d%d%d",&x,&y,&w); v[x].push_back(pii(y,w)); v[y].push_back(pii(x,w)); } while(m--) { scanf("%d",&val); flag=false; memset(vis,false,sizeof(vis)); Calc(1,n); printf(flag?"AYE\n":"NAY\n"); } return 0; }