题意:给你一棵树,每个点被染成了k种颜色之一或者没有颜色(保证每种颜色至少出现一次)。你要恰好切断k − 1条边,使得不存在两个异色点在同一连通块内。求方案数,答案对998244353取模。
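解:对每种颜色构建最小斯坦纳树(包含该颜色所有点的最小连通子树),并判断不同颜色的斯坦纳树是否相交,若相交则答案为0。我用的是类似树上差分的做法:先求出每种颜色所有点的LCA,再自底向上把颜色沿父边往上推,推到该颜色的LCA为止;若某个点被推到了两种不同的颜色,则无解。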
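然后把同一颜色的点缩成一个点,在新树上做树形DP。f[x]表示x子树内(其余连通块都已恰含一个关键点)、x所在连通块内恰有一个关键点(有色点)的方案数;h[x]表示x所在连通块内没有关键点的方案数。合并儿子y时:若切断边(x, y),y所在块必须含关键点,贡献f[y];若保留该边,两块合并且合并后至多一个关键点。于是 f[x] ← f[x]·(f[y] + h[y]) + h[x]·f[y],h[x] ← h[x]·(f[y] + h[y])。以颜色1缩成的点为根,答案为f[根]。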
// Count the ways to cut exactly k - 1 edges of a tree so that no resulting
// component contains two differently colored vertices, modulo 998244353.
//
// Input : n k; col[1..n] (0 = uncolored, otherwise 1..k); then n - 1 edges.
// Output: the count, or 0 as soon as two colors' minimal subtrees overlap.
//
// Method: fold the LCA over every color's vertices, push each color upward
// until it reaches its LCA (a vertex reached by two colors => answer 0),
// contract every color class into one vertex, then run the f/h tree DP.
// All traversals use explicit stacks, so a path-shaped tree (depth ~ n)
// cannot overflow the call stack.
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <utility>
#include <vector>

namespace {

const int MO = 998244353;

typedef std::vector<std::vector<int> > Graph;

// Two colors' minimal subtrees share a vertex, so no valid cut exists:
// print 0 and terminate the program.
void ERR() {
    std::puts("0");
    std::exit(0);
}

// The input tree rooted at one vertex, as recorded by a single DFS.
struct RootedTree {
    std::vector<int> parent;  // parent[root] == 0
    std::vector<int> depth;   // depth[root] == 1
    std::vector<int> order;   // vertices in DFS preorder (parents before children)
    std::vector<int> tour;    // Euler tour: a vertex is listed again after each child
    std::vector<int> first;   // first[x] = index of x's first appearance in tour
};

// Roots the tree `adj` (vertices 1..n) at `root` using an explicit stack.
RootedTree root_tree(const Graph& adj, int root) {
    const int n = static_cast<int>(adj.size()) - 1;
    RootedTree t;
    t.parent.assign(n + 1, 0);
    t.depth.assign(n + 1, 0);
    t.first.assign(n + 1, 0);
    t.order.reserve(n);
    t.tour.reserve(2 * n);

    std::vector<std::size_t> next(n + 1, 0);  // next neighbour index to try
    std::vector<int> stack(1, root);
    t.depth[root] = 1;
    t.order.push_back(root);
    t.tour.push_back(root);
    while (!stack.empty()) {
        const int x = stack.back();
        if (next[x] == adj[x].size()) {
            // x is finished: step back to its parent and list the parent again.
            stack.pop_back();
            if (!stack.empty()) t.tour.push_back(stack.back());
            continue;
        }
        const int y = adj[x][next[x]++];
        if (y == t.parent[x]) continue;
        t.parent[y] = x;
        t.depth[y] = t.depth[x] + 1;
        t.first[y] = static_cast<int>(t.tour.size());
        t.order.push_back(y);
        t.tour.push_back(y);
        stack.push_back(y);
    }
    return t;
}

// O(1) lowest-common-ancestor queries: the LCA of a and b is the shallowest
// vertex of the Euler tour between their first appearances; a sparse table
// answers that range-minimum query.
class LcaTable {
public:
    explicit LcaTable(const RootedTree& t) : t_(t) {
        const int len = static_cast<int>(t.tour.size());
        lg_.assign(len + 1, 0);
        for (int i = 2; i <= len; ++i) lg_[i] = lg_[i >> 1] + 1;
        table_.resize(lg_[len] + 1);
        table_[0] = t.tour;
        for (int j = 1; j <= lg_[len]; ++j) {
            const int half = 1 << (j - 1);
            const std::vector<int>& prev = table_[j - 1];
            std::vector<int>& cur = table_[j];
            cur.resize(len - (1 << j) + 1);
            for (std::size_t i = 0; i < cur.size(); ++i) {
                cur[i] = shallower(prev[i], prev[i + half]);
            }
        }
    }

    int lca(int a, int b) const {
        int l = t_.first[a];
        int r = t_.first[b];
        if (l > r) std::swap(l, r);
        const int k = lg_[r - l + 1];
        return shallower(table_[k][l], table_[k][r - (1 << k) + 1]);
    }

private:
    // The shallower of two vertices (the second one on a depth tie).
    int shallower(int a, int b) const {
        return t_.depth[a] < t_.depth[b] ? a : b;
    }

    const RootedTree& t_;
    std::vector<int> lg_;                   // lg_[i] = floor(log2(i))
    std::vector<std::vector<int> > table_;  // table_[j][i]: shallowest of tour[i, i + 2^j)
};

// Extends every color to its minimal subtree in place: a color climbs from
// its vertices toward the root and stops after reaching a vertex x with
// fr[x] != 0 (the LCA of that color). Calls ERR() if any vertex would get
// two different colors.
void spread_colors(const RootedTree& t, const std::vector<int>& fr,
                   std::vector<int>& col) {
    const int n = static_cast<int>(t.order.size());
    std::vector<int> from_below(col.size(), 0);  // color arriving from children, 0 = none
    // Reverse preorder visits every child before its parent.
    for (int i = n - 1; i >= 0; --i) {
        const int x = t.order[i];
        int c = from_below[x];
        if (col[x]) {
            if (c && col[x] != c) ERR();
            c = col[x];
        }
        col[x] = c;
        const int up = fr[x] ? 0 : c;  // a color never climbs above its LCA
        const int p = t.parent[x];
        if (up && p) {
            if (from_below[p] && from_below[p] != up) ERR();
            from_below[p] = up;
        }
    }
}

// Contracts every color class into its representative rep[c]; uncolored
// vertices stay themselves, and edges inside one color class disappear.
Graph contract(const Graph& adj, const std::vector<int>& col,
               const std::vector<int>& rep) {
    const int n = static_cast<int>(adj.size()) - 1;
    Graph out(n + 1);
    for (int x = 1; x <= n; ++x) {
        const int cx = col[x] ? rep[col[x]] : x;
        for (std::size_t i = 0; i < adj[x].size(); ++i) {
            const int y = adj[x][i];
            const int cy = col[y] ? rep[col[y]] : y;
            if (cx != cy) out[cx].push_back(cy);  // adj is symmetric, so out is too
        }
    }
    return out;
}

// Tree DP on the contracted tree rooted at `root`.
//   f[x]: ways inside x's subtree where x's component holds exactly one
//         colored vertex (every already-closed component holds exactly one);
//   h[x]: same, but x's component holds no colored vertex.
// Merging child y: cutting (x, y) requires f[y]; keeping it merges components.
// Returns f[root] modulo MO.
int count_cuts(const Graph& g, const std::vector<int>& col, int root) {
    const int n = static_cast<int>(g.size()) - 1;
    std::vector<int> f(n + 1, 0), h(n + 1, 0), parent(n + 1, 0), order;
    std::vector<int> stack(1, root);
    while (!stack.empty()) {
        const int x = stack.back();
        stack.pop_back();
        order.push_back(x);
        f[x] = col[x] ? 1 : 0;
        h[x] = col[x] ? 0 : 1;
        for (std::size_t i = 0; i < g[x].size(); ++i) {
            const int y = g[x][i];
            if (y == parent[x]) continue;
            parent[y] = x;
            stack.push_back(y);
        }
    }
    // Children appear after their parent in `order`; fold them in reverse.
    for (std::size_t i = order.size(); i-- > 1;) {  // order[0] is the root
        const int y = order[i];
        const int x = parent[y];
        const long long any = (static_cast<long long>(f[y]) + h[y]) % MO;
        const long long nf = (1LL * f[x] * any + 1LL * h[x] * f[y]) % MO;
        const long long nh = 1LL * h[x] * any % MO;
        f[x] = static_cast<int>(nf);
        h[x] = static_cast<int>(nh);
    }
    return f[root];
}

}  // namespace

int main() {
    int n = 0, m = 0;
    if (std::scanf("%d%d", &n, &m) != 2 || n < 1 || m < 1) return 0;  // malformed input

    std::vector<int> col(n + 1, 0);
    std::vector<std::vector<int> > byColor(m + 1);
    for (int i = 1; i <= n; ++i) {
        std::scanf("%d", &col[i]);
        if (col[i]) byColor[col[i]].push_back(i);
    }
    Graph adj(n + 1);
    for (int i = 1; i < n; ++i) {
        int x = 0, y = 0;
        std::scanf("%d%d", &x, &y);
        adj[x].push_back(y);
        adj[y].push_back(x);
    }

    const RootedTree tree = root_tree(adj, 1);
    const LcaTable lcaTable(tree);

    // fr[x] = c when x is the LCA of all vertices of color c; rep[c] is the
    // vertex that color c is contracted into. NOTE: relies on the statement's
    // guarantee that every color 1..k occurs at least once.
    std::vector<int> fr(n + 1, 0);
    std::vector<int> rep(m + 1, 0);
    for (int c = 1; c <= m; ++c) {
        const std::vector<int>& vs = byColor[c];
        int top = vs[0];
        for (std::size_t j = 1; j < vs.size(); ++j) top = lcaTable.lca(top, vs[j]);
        fr[top] = c;
        rep[c] = vs[0];
    }

    spread_colors(tree, fr, col);
    const Graph small = contract(adj, col, rep);
    std::printf("%d\n", count_cuts(small, col, rep[1]));
    return 0;
}