有丶抽象,学到自闭

参考的文章:

zcysky:【学习笔记】dsu on tree

Arpa:[Tutorial] Sack (dsu on tree)

 


 

先康一康模板题吧:CF 600E($Lomsat$ $gelral$)

虽然已经用莫队搞过一遍了(可以参考之前写的博客~),但这个还是差距挺大

我们如果对于每个节点暴力统计答案,是$O(N^2)$的复杂度:最坏情况下整棵树是一条链,对于每个节点的统计平均下来是$O(N)$的

具体是怎么做的呢?

对于以当前节点$x$为根的子树,我们建立$cnt$和$sum$两个数组(其实只要$sum$数组就够用啦)

$cnt[i]$:颜色$i$在子树中出现的次数

$sum[i]$:在子树中出现次数为$i$的颜色,其颜色的序号之和

我们还可以建立一个指针$top$,表示出现次数最多的颜色出现了多少次,在改变$cnt$数组的时候可以顺便维护下$top$

那么,对于这个子树,我们只要跑一边$dfs$,把所有后代全部统计一波,最后的结果就是$ans[x]=sum[top]$

现在我们希望能够降低对于每个节点统计的复杂度

 

$dsu$ $on$ $tree$是$O(N\cdot logN)$的做法,需要用到一些树剖的知识

在这道题目中,拿到了这颗树的连边,我们先用树剖怼上去

不用太着急,只要进行第一个$dfs$、得到$son$数组(即每个节点的重儿子)就够了

接下来的蛇皮操作需要理解一下

对于以节点$x$为根的子树,我们这样计算其结果$ans[x]$:

  1. 将$x$的儿子分成两种,一种是重儿子,另一种是轻儿子
  2. 我们先按照最上面方法的递归计算所有轻儿子的结果,计算完以后,不对$cnt$、$sum$、$top$进行任何保留(保留与否的操作在下一层递归的第$5$步实现)
  3. 我们再递归计算重儿子的的结果,但是计算完后,保留计算重儿子答案时的$cnt$、$sum$、$top$
  4. 结束递归、回到当前节点$x$这层以后,由于保留了计算重儿子时的统计信息,我们此时对重儿子及其子树的信息是完全清楚的,但是我们依然不清楚轻儿子和它们的子树的信息,所以我们再递归的$dfs$一遍轻儿子,将信息放进计算重儿子的数组;这时,我们得到的$cnt$、$sum$、$top$与暴力统计得到的是一模一样的
  5. 但是节点$x$不一定是其父节点$fa[x]$的重儿子!如果不是,那么$dfs$一遍$x$将所有信息清空;否则就保留(这里就是第$2$、$3$步中的保留/不保留的具体实现的位置)

我们在第$2$步仅仅是为了计算轻儿子的结果,且不想让其统计信息干扰我们需要的重儿子的统计信息

绕的地方在于,我们在第$2$步的“不保留”实际上是在计算某个轻儿子后,将以这个轻儿子为根的子树信息全部从$cnt$、$sum$、$top$中抹去(即是在轻儿子的第$5$步实现,并不是在$x$的第$2$步) ←我一开始因为没理解这个自闭了好久

看上去挺暴力的...分析一下为什么是$O(N\cdot logN)$

对于每个节点$x$,它仅可能被 以其祖先为根的子树 统计,所以它被统计的次数与其到根节点的路径长度相关

但是由于我们将重节点的统计信息保留,所以对于每条重链,只会真正意义上$dfs$到$x$一次:在重链的底端;重链上的其余节点可以通过被保留的统计信息了解$x$的情况、不需要$dfs$

综上,$x$被统计的次数即为其到根节点的路径上链的个数,是$O(logN)$级别的

用这个思路写出的代码如下:

#include <cstdio>
#include <cstring>
#include <cmath>
#include <vector>
using namespace std;

typedef long long ll;
const int N=100005;

int n;
int c[N];
vector<int> v[N];

int fa[N],sz[N],son[N];

inline void dfs(int x,int f)
{
    fa[x]=f;
    sz[x]=1;
    
    for(int i=0;i<v[x].size();i++)
    {
        int next=v[x][i];
        if(next==fa[x])
            continue;
        
        dfs(next,x);
        sz[x]+=sz[next];
        if(!son[x] || sz[son[x]]<sz[next])
            son[x]=next;
    }
}

int top,cnt[N];
ll sum[N],ans[N];

inline void Add(int x,int num)
{
    sum[cnt[c[x]]]-=(ll)c[x];
    cnt[c[x]]+=num;
    sum[cnt[c[x]]]+=(ll)c[x];
    if(sum[top+1])
        top++;
    if(!sum[top])
        top--;
    
    for(int i=0;i<v[x].size();i++)
    {
        int next=v[x][i];
        if(next==fa[x])
            continue;
        Add(next,num);
    }
}

inline void Solve(int x,int keep)
{
    for(int i=0;i<v[x].size();i++)
    {
        int next=v[x][i];
        if(next==fa[x] || next==son[x])
            continue;
        Solve(next,0);
    }
    
    if(son[x])
            Solve(son[x],1);
    
    for(int i=0;i<v[x].size();i++)
    {
        int next=v[x][i];
        if(next==fa[x] || next==son[x])
            continue;
        Add(next,1);
    }
    
    sum[cnt[c[x]]]-=(ll)c[x];
    cnt[c[x]]++;
    sum[cnt[c[x]]]+=(ll)c[x];
    if(sum[top+1])
        top++;
    ans[x]=sum[top];
    
    if(!keep)
        Add(x,-1);
}

int main()
{
//    freopen("input.txt","r",stdin);
    scanf("%d",&n);
    for(int i=1;i<=n;i++)
        scanf("%d",&c[i]);
    for(int i=1;i<n;i++)
    {
        int x,y;
        scanf("%d%d",&x,&y);
        v[x].push_back(y);
        v[y].push_back(x);
    }
    
    dfs(1,0);
    Solve(1,1);
    
    for(int i=1;i<=n;i++)
        printf("%lld ",ans[i]);
    return 0;
}
View Code

相关文章: