[ZJOI2008]树的统计
第一遍树链剖分,打的很难受。
其中拉闸了,检查真是费劲。
树链剖分是什么?
树链剖分,计算机术语,指一种对树进行划分的算法,它先通过轻重边剖分将树分为多条链,保证每个点属于且只属于一条链,然后再通过数据结构(树状数组、SBT、SPLAY、线段树等)来维护每一条链。
树链剖分可以支持链上求和,链上求最值,链上修改等线段树的操作。
但若断开一条边或者连接两个点,保证两个点连接后依然是棵树。这样树链剖分就虚了,因为线段树不支持这种操作,就需要把线段树换成splay,于是LCT = 树剖 + splay。
说明:
重孩子:儿子节点所有孩子中size最大的
轻孩子:儿子节点中除了重儿子的节点
重边:连接重儿子的边
轻边:连接轻儿子的边
重链:重边连成的链
轻链:轻边连成的链
a[i] 表示节点 i 权值
f[i] 表示节点 i 的父亲在原树中的位置
son[i] 表示节点 i 的重儿子在原树中的位置
top[i] 表示节点 i 所在链的顶端节点在原树中的位置,就是深度最小的
size[i] 表示以 i 为根的子树节点个数
tid[i] 表示树中节点 i 剖分后的新编号
rank[i] 表示剖分后的节点 i 在原树中的位置
deep[i] 表示节点 i 深度,根节点深度为 1
实现方法:
第一遍dfs可以预处理出size,deep,f,son数组
第二遍dfs可以预处理出top,tid,rank数组,通过优先搜索重边,然后搜索轻边
树链剖分目的是把树上的边剖分成一个链,就是一个线段,标号是连续的。
为什么要先搜索重边呢?
可以看出,这样搜可以使得重链上的点的dfs序是连续的,可以用线段树来维护。
如何查询呢?
判断两点是否属于同一条重链,如果属于,就直接修改,因为他们是连续的,如果不属于,就从深度大点开始不停地找他父亲跳轻链,其中深度是不停地在变的,也就是说,两个点可能会轮着跳,直到属于同一个重链。现在看来,轻边实际上是连接重链的东西。
1 #include <cstdio> 2 #include <cstring> 3 #include <iostream> 4 #define rt 1, 1, n 5 #define ls o << 1, l, m 6 #define rs o << 1 | 1, m + 1, r 7 8 using namespace std; 9 10 const int maxn = 300001; 11 const int INF = 99999999; 12 int n, m, q, cnt, tim; 13 int a[maxn], head[maxn], to[maxn << 2], next[maxn << 2], deep[maxn], size[maxn]; 14 int son[maxn], top[maxn], f[maxn], tid[maxn], rank[maxn], sumv[maxn], maxv[maxn]; 15 //a节点权值, deep节点深度, size以x为根的子树节点个数, son重儿子, top当前节点所在链的顶端节点 16 //f当前节点父亲, tid保存树中每个节点剖分后的新编号, rank保存剖分后的节点在线段树中的位置 17 18 void add(int x, int y) 19 { 20 to[cnt] = y; 21 next[cnt] = head[x]; 22 head[x] = cnt++; 23 } 24 25 void dfs1(int u, int father)//记录所有重边 26 { 27 int i, v; 28 f[u] = father; 29 size[u] = 1; 30 deep[u] = deep[father] + 1; 31 for(i = head[u]; i != -1; i = next[i]) 32 { 33 v = to[i]; 34 if(v == father) continue; 35 dfs1(v, u); 36 size[u] += size[v]; 37 if(son[u] == -1 || size[v] > size[son[u]]) son[u] = v; 38 } 39 } 40 41 void dfs2(int u, int tp) 42 { 43 int i, v; 44 top[u] = tp; 45 tid[u] = ++tim; 46 rank[tim] = u; 47 if(son[u] == -1) return; 48 dfs2(son[u], tp);//重边 49 for(i = head[u]; i != -1; i = next[i]) 50 { 51 v = to[i]; 52 if(v != son[u] && v != f[u]) dfs2(v, v);//轻边 53 } 54 } 55 56 void pushup(int o) 57 { 58 sumv[o] = sumv[o << 1] + sumv[o << 1 | 1]; 59 maxv[o] = max(maxv[o << 1], maxv[o << 1 | 1]); 60 } 61 62 void updata(int o, int l, int r, int d, int x) 63 { 64 int m = (l + r) >> 1; 65 if(l == r) 66 { 67 sumv[o] = maxv[o] = x; 68 return; 69 } 70 if(d <= m) updata(ls, d, x); 71 else updata(rs, d, x); 72 pushup(o); 73 } 74 75 void build(int o, int l, int r) 76 { 77 int m = (l + r) >> 1; 78 if(l == r) 79 { 80 sumv[o] = maxv[o] = a[rank[l]]; 81 return; 82 } 83 build(ls); 84 build(rs); 85 pushup(o); 86 } 87 88 int querymax(int o, int l, int r, int ql, int qr) 89 { 90 int m = (l + r) >> 1, ans = -INF; 91 if(ql <= l && r <= qr) return maxv[o]; 92 if(ql <= m) ans = max(ans, querymax(ls, ql, qr)); 93 if(m < qr) ans = max(ans, querymax(rs, ql, qr)); 94 pushup(o); 95 return ans; 96 } 97 98 int qmax(int u, int v) 99 { 100 int ans = -INF; 101 while(top[u] != top[v])//判断是否在一条重链上 102 { 103 if(deep[top[u]] < deep[top[v]]) swap(u, v);//深度不同,先处理深度大的 104 ans = max(ans, querymax(rt, tid[top[u]], tid[u])); 105 u = f[top[u]]; 106 } 107 if(deep[u] < deep[v]) swap(u, v);//在同一条重链上了 108 ans = max(ans, querymax(rt, tid[v], tid[u])); 109 return ans; 110 } 111 112 int querysum(int o, int l, int r, int ql, int qr) 113 { 114 int m = (l + r) >> 1, ans = 0; 115 if(ql <= l && r <= qr) return sumv[o]; 116 if(ql <= m) ans += querysum(ls, ql, qr); 117 if(m < qr) ans += querysum(rs, ql, qr); 118 pushup(o); 119 return ans; 120 } 121 122 int qsum(int u, int v) 123 { 124 int ans = 0; 125 while(top[u] != top[v])//判断是否在一条重链上 126 { 127 if(deep[top[u]] < deep[top[v]]) swap(u, v);//深度不同,先处理深度大的 128 ans += querysum(rt, tid[top[u]], tid[u]); 129 u = f[top[u]]; 130 } 131 if(deep[u] < deep[v]) swap(u, v);//在同一条重链上了 132 ans += querysum(rt, tid[v], tid[u]); 133 return ans; 134 } 135 136 int main() 137 { 138 int i, j, x, y; 139 char s[11]; 140 memset(head, -1, sizeof(head)); 141 memset(son, -1, sizeof(son)); 142 scanf("%d", &n); 143 for(i = 1; i < n; i++) 144 { 145 scanf("%d %d", &x, &y); 146 add(x, y); 147 add(y, x); 148 } 149 for(i = 1; i <= n; i++) scanf("%d", &a[i]); 150 dfs1(1, 1);//根节点和他的父亲 151 dfs2(1, 1);//根节点和链头结点 152 build(rt); 153 scanf("%d", &q); 154 for(i = 1; i <= q; i++) 155 { 156 scanf("%s %d %d", s, &x, &y); 157 if(s[1] == 'H') updata(rt, tid[x], y);//把位置为x的点修改为y 158 if(s[1] == 'M') printf("%d\n", qmax(x, y)); 159 if(s[1] == 'S') printf("%d\n", qsum(x, y)); 160 } 161 return 0; 162 }