题意:
给一棵树,每次询问删掉两条边,问剩下的三棵树的最大直径
点10W,询问10W,询问相互独立
Solution:
考虑线段树/倍增维护树的直径
考虑一个点集的区间 [l, r]
而我们知道了有 l <= k < r,
且知道 [l, k] 和 [k + 1, r] 两个区间的最长链的端点及长度
假设两个区间的直径端点分别为 (l1, r1) 和 (l2, r2)
那么 [l, r] 这个区间的直径长度为
dis(l1, r1) dis(l1, l1) dis(l1, r2)
dis(r1, l2) dis(r1, r2) dis(l2, r2)
六个值中的最大值
本题因为操作子树,所以我们维护dfs序的区间最长链即可
证明:
首先有一个结论:
树上任意一个点在树中的最远点是树的直径的某个端点。我们可以用反证法轻易地证明这一点。
再扩展一下,有以下结论:树上任意一个点在树中的一个点集中的最远点是该点集中最长链的一个端点。
其实我们把点集等价地看为一棵虚树,然后就能用相似的证法解决了。
代码:
1 #include <stdio.h> 2 #include <algorithm> 3 4 using namespace std; 5 6 const int N = 2e5 + 5; 7 8 int T, n, m; 9 10 int len, head[N], ST[20][N]; 11 12 struct edge{int u, v, w;}ee[N]; 13 14 int cnt, fa[N], log_2[N], st[N], en[N], dfn[N], dis[N], dep[N], pos[N]; 15 16 struct edges{int to, next, cost;}e[N]; 17 18 inline void add(int u, int v, int w) { 19 e[++ len] = (edges){v, head[u], w}, head[u] = len; 20 e[++ len] = (edges){u, head[v], w}, head[v] = len; 21 } 22 23 inline void dfs1(int u) { 24 st[u] = ++ cnt, dfn[cnt] = u; 25 for (int v, i = head[u]; i; i = e[i].next) { 26 v = e[i].to; 27 if (v == fa[u]) continue; 28 fa[v] = u, dep[v] = dep[u] + 1; 29 dis[v] = dis[u] + e[i].cost, dfs1(v); 30 } 31 en[u] = cnt; 32 } 33 34 inline void dfs2(int u) { 35 dfn[++ cnt] = u, pos[u] = cnt; 36 for (int v, i = head[u]; i; i = e[i].next) { 37 v = e[i].to; 38 if (v == fa[u]) continue; 39 dfs2(v), dfn[++ cnt] = u; 40 } 41 } 42 43 int mmin(int x, int y) { 44 if (dep[x] < dep[y]) return x; 45 return y; 46 } 47 48 inline int lca(int u, int v) { 49 static int w; 50 if (pos[u] > pos[v]) swap(u, v); 51 w = log_2[pos[v] - pos[u] + 1]; 52 return mmin(ST[w][pos[u]], ST[w][pos[v] - (1 << w) + 1]); 53 } 54 55 inline int dist(int u, int v) { 56 int Lca = lca(u, v); 57 return dis[u] + dis[v] - dis[Lca] * 2; 58 } 59 60 inline void build() { 61 for (int i = 1; i <= cnt; i ++) 62 ST[0][i] = dfn[i]; 63 for (int i = 1; i < 20; i ++) 64 for (int j = 1; j <= cnt; j ++) 65 if (j + (1 << (i - 1)) > cnt) ST[i][j] = ST[i - 1][j]; 66 else ST[i][j] = mmin(ST[i - 1][j], ST[i - 1][j + (1 << (i - 1))]); 67 } 68 69 int M; 70 71 struct node { 72 int l, r, dis; 73 }tr[N << 1]; 74 75 inline void update(int o, int o1, int o2) { 76 static int d; 77 static node tmp; 78 if (tr[o1].dis == -1) {tr[o] = tr[o2]; return;} 79 if (tr[o2].dis == -1) {tr[o] = tr[o1]; return;} 80 if (tr[o1].dis > tr[o2].dis) tmp = tr[o1]; 81 else tmp = tr[o2]; 82 d = dist(tr[o1].l, tr[o2].l); 83 if (d > tmp.dis) tmp.l = tr[o1].l, tmp.r = tr[o2].l, tmp.dis = d; 84 d = dist(tr[o1].l, tr[o2].r); 85 if (d > tmp.dis) tmp.l = tr[o1].l, tmp.r = tr[o2].r, tmp.dis = d; 86 d = dist(tr[o1].r, tr[o2].l); 87 if (d > tmp.dis) tmp.l = tr[o1].r, tmp.r = tr[o2].l, tmp.dis = d; 88 d = dist(tr[o1].r, tr[o2].r); 89 if (d > tmp.dis) tmp.l = tr[o1].r, tmp.r = tr[o2].r, tmp.dis = d; 90 tr[o] = tmp; 91 } 92 93 inline void ask(int s, int t) { 94 if (s > t) return; 95 for (s += M - 1, t += M + 1; s ^ t ^ 1; s >>= 1, t >>= 1) { 96 if (~s&1) update(0, 0, s ^ 1); 97 if ( t&1) update(0, 0, t ^ 1); 98 } 99 } 100 101 inline int get_char() { 102 static const int SIZE = 1 << 23; 103 static char *T, *S = T, buf[SIZE]; 104 if (S == T) { 105 T = fread(buf, 1, SIZE, stdin) + (S = buf); 106 if (S == T) return -1; 107 } 108 return *S ++; 109 } 110 111 inline void in(int &x) { 112 static int ch; 113 while (ch = get_char(), ch > 57 || ch < 48);x = ch - 48; 114 while (ch = get_char(), ch > 47 && ch < 58) x = x * 10 + ch - 48; 115 } 116 117 int main() { 118 int u, v, w, ans; 119 log_2[1] = 0; 120 for (int i = 2; i <= 200000; i ++) 121 if (i == 1 << (log_2[i - 1] + 1)) 122 log_2[i] = log_2[i - 1] + 1; 123 else log_2[i] = log_2[i - 1]; 124 for (in(T); T --; ) { 125 in(n), in(m), cnt = len = 0; 126 for (int i = 1; i <= n; i ++) 127 head[i] = 0; 128 for (int i = 1; i < n; i ++) { 129 in(ee[i].u), in(ee[i].v), in(ee[i].w); 130 add(ee[i].u, ee[i].v, ee[i].w); 131 } 132 dfs1(1); 133 for (M = 1; M < n + 2; M <<= 1); 134 for (int i = 1; i <= n; i ++) 135 tr[i + M].l = tr[i + M].r = dfn[i], tr[i + M].dis = 0; 136 for (int i = n + M + 1; i <= (M << 1) + 1; i ++) 137 tr[i].dis = -1; 138 cnt = 0, dfs2(1), build(); 139 for (int i = M; i; i --) 140 update(i, i << 1, i << 1 | 1); 141 for (int i = 1; i < n; i ++) 142 if (dep[ee[i].u] > dep[ee[i].v]) 143 swap(ee[i].u, ee[i].v); 144 for (int u, v, i = 1; i <= m; i ++) { 145 in(u), in(v), ans = 0; 146 u = ee[u].v, v = ee[v].v, w = lca(u, v); 147 if (w == u || w == v) { 148 if (w != u) swap(u, v); 149 tr[0].dis = -1, ask(1, st[u] - 1), ask(en[u] + 1, n), ans = max(ans, tr[0].dis); 150 tr[0].dis = -1, ask(st[u], st[v] - 1), ask(en[v] + 1, en[u]), ans = max(ans, tr[0].dis); 151 tr[0].dis = -1, ask(st[v], en[v]), ans = max(ans, tr[0].dis); 152 } 153 else { 154 if (st[u] > st[v]) swap(u, v); 155 tr[0].dis = -1, ask(1, st[u] - 1), ask(en[u] + 1, st[v] - 1), ask(en[v] + 1, n), ans = max(ans, tr[0].dis); 156 tr[0].dis = -1, ask(st[u], en[u]), ans = max(ans, tr[0].dis); 157 tr[0].dis = -1, ask(st[v], en[v]), ans = max(ans, tr[0].dis); 158 } 159 printf("%d\n", ans); 160 } 161 } 162 return 0; 163 }