CF1118F2 - Tree Cutting

题意:给你一棵 $n$ 个点的树,每个点被染成 $k$ 种颜色之一,或者没有颜色(保证每种颜色至少出现一次)。你要恰好切断 $k-1$ 条边,使得不存在两个异色点位于同一连通块内。求方案数,答案对 $998244353$ 取模。

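解:对每种颜色,求出连接该颜色所有点的最小连通子树(即斯坦纳树),并判断这些子树是否两两不交。实现上用类似树上差分的做法:先求出每种颜色所有点的 LCA 并打上标记,再自底向上把颜色传给父亲,传到该颜色的 LCA 处停止;若某个点同时收到两种不同的颜色,说明有两棵子树相交,答案为 0。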

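然后把同一颜色的点(即同一棵斯坦纳树上的点)缩成一个点,在新树上做树形 DP。设 $f_x$ 表示 $x$ 的子树内、$x$ 所在连通块恰含一个关键点(有颜色的点)的方案数,$h_x$ 表示 $x$ 所在连通块内没有关键点的方案数。初值:关键点 $f_x=1,\ h_x=0$,否则 $f_x=0,\ h_x=1$。合并儿子 $y$ 时,只有 $y$ 所在块含关键点才能切断边 $(x,y)$,故 $f_x' = f_x(f_y+h_y) + h_x f_y$,$h_x' = h_x(f_y+h_y)$。答案为根的 $f$ 值。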

  1 #include <bits/stdc++.h>
  2 
  3 const int N = 300010, MO = 998244353;
  4 
  5 struct Edge {
  6     int nex, v, len;
  7 };
  8 
  9 int n, m, col[N], pw[N << 1], fr[N], imp[N], K, stk[N], top, f[N], h[N];
 10 std::vector<int> v[N];
 11 
 12 inline void ERR() {
 13     puts("0");
 14     exit(0);
 15     return;
 16 }
 17 
 18 struct G {
 19     Edge edge[N << 1]; int tp;
 20     int e[N], d[N], fa[N], pos[N], num, ST[N << 1][20];
 21     G(){}
 22     inline void add(int x, int y, int z = 0) {
 23         edge[++tp].v = y;
 24         edge[tp].len = z;
 25         edge[tp].nex = e[x];
 26         e[x] = tp;
 27         return;
 28     }
 29     void DFS_1(int x, int f) {
 30         d[x] = d[f] + 1;
 31         fa[x] = f;
 32         pos[x] = ++num;
 33         ST[num][0] = x;
 34         for(int i = e[x]; i; i = edge[i].nex) {
 35             int y = edge[i].v;
 36             if(y == f) continue;
 37             DFS_1(y, x);
 38             ST[++num][0] = x;
 39         }
 40         return;
 41     }
 42     inline void pre1(int x = 1) {
 43         DFS_1(x, 0);
 44         return;
 45     }
 46     inline void lcapre() {
 47         for(int j = 1; j <= pw[num]; j++) {
 48             for(int i = 1; i + (1 << j) - 1 <= num; i++) {
 49                 if(d[ST[i][j - 1]] < d[ST[i + (1 << (j - 1))][j - 1]]) {
 50                     ST[i][j] = ST[i][j - 1];
 51                 }
 52                 else {
 53                     ST[i][j] = ST[i + (1 << (j - 1))][j - 1];
 54                 }
 55             }
 56         }
 57         return;
 58     }
 59     inline int lca(int x, int y) {
 60         x = pos[x];
 61         y = pos[y];
 62         if(x > y) std::swap(x, y);
 63         int t = pw[y - x + 1];
 64         if(d[ST[x][t]] < d[ST[y - (1 << t) + 1][t]]) {
 65             return ST[x][t];
 66         }
 67         else {
 68             return ST[y - (1 << t) + 1][t];
 69         }
 70     }
 71     int recol(int x) {
 72         int Col = 0;
 73         for(int i = e[x]; i; i = edge[i].nex) {
 74             int y = edge[i].v;
 75             if(y == fa[x]) {
 76                 continue;
 77             }
 78             int c = recol(y);
 79             if(c && Col && c != Col) {
 80                 ERR();
 81             }
 82             if(c && !Col) {
 83                 Col = c;
 84             }
 85         }
 86         if(col[x]) {
 87             if(Col && col[x] != Col) {
 88                 ERR();
 89             }
 90             else {
 91                 Col = col[x];
 92             }
 93         }
 94         col[x] = Col;
 95         if(fr[x]) {
 96             Col = 0;
 97         }
 98         return Col;
 99     }
100     inline int build_t(G &gr) {
101         /*printf("build virtue tree \n");
102         for(int i = 1; i <= K; i++) printf("%d ", imp[i]);
103         puts("\n");*/
104         
105         stk[top = 1] = imp[1];
106         for(int i = 2; i <= K; i++) {
107             int x = imp[i], y = lca(x, stk[top]);
108             while(top > 1 && pos[y] <= pos[stk[top - 1]]) {
109                 gr.add(stk[top - 1], stk[top], d[stk[top]] - d[stk[top - 1]]);
110                 top--;
111             }
112             if(y != stk[top]) {
113                 gr.add(y, stk[top], d[stk[top]] - d[y]);
114                 stk[top] = y;
115             }
116             stk[++top] = x;
117         }
118         while(top > 1) {
119             gr.add(stk[top - 1], stk[top], d[stk[top]] - d[stk[top - 1]]);
120             top--;
121         }
122         return stk[top];
123     }
124     int cal(int x) {
125         int ans = 1;
126         for(int i = e[x]; i; i = edge[i].nex) {
127             int y = edge[i].v;
128             int t = cal(y);
129             ans = 1ll * ans * t % MO * edge[i].len % MO;
130         }
131         return ans;
132     }
133     void DP(int x, int father) {
134         f[x] = col[x] ? 1 : 0;
135         h[x] = col[x] ? 0 : 1;
136         for(int i = e[x]; i; i = edge[i].nex) {
137             int y = edge[i].v;
138             if(y == father) continue;
139             DP(y, x);
140             /// 
141             int t, t2;
142             t = 1ll * f[x] * (f[y] + h[y]) % MO + 1ll * f[y] * h[x] % MO;
143             t2 = 1ll * h[x] * (h[y] + f[y]) % MO;
144             f[x] = t % MO;
145             h[x] = t2;
146         }
147         //printf("x = %d f[x] = %d h[x] = %d \n", x, f[x], h[x]);
148         return;
149     }
150 }g[3];
151 
152 inline bool cmp(const int &a, const int &b) {
153     return g[1].pos[a] < g[1].pos[b];
154 }
155 
156 void out(int x, G &gr) {
157     for(int i = gr.e[x]; i; i = gr.edge[i].nex) {
158         int y = gr.edge[i].v;
159         if(y == gr.fa[x]) continue;
160         //printf("%d -> %d len = %d \n", x, y, gr.edge[i].len);
161         out(y, gr);
162     }
163     return;
164 }
165 
166 int main() {
167     
168     scanf("%d%d", &n, &m);
169     for(int i = 2; i <= n * 2; i++) pw[i] = pw[i >> 1] + 1;
170     for(int i = 1; i <= n; i++) {
171         scanf("%d", &col[i]);
172         if(col[i]) {
173             v[col[i]].push_back(i);
174         }
175     }
176 
177     for(int i = 1; i < n; i++) {
178         int x, y;
179         scanf("%d%d", &x, &y);
180         g[0].add(x, y);
181         g[0].add(y, x);
182     }
183     g[0].pre1();
184     g[0].lcapre();
185     for(int i = 1; i <= m; i++) {
186         //std::sort(v[i].begin(), v[i].end());
187         int len = v[i].size(), x = v[i][0];
188         for(int j = 1; j < len; j++) {
189             x = g[0].lca(x, v[i][j]);
190         }
191         fr[x] = i;
192     }
193     
194     g[0].recol(1);
195     
196     /*for(int i = 1; i <= n; i++) {
197         printf("i = %d col = %d \n", i, col[i]);
198     }
199     puts("");*/
200     
201     for(int x = 1; x <= n; x++) {
202         //printf("x = %d \n", x);
203         for(int i = g[0].e[x]; i; i = g[0].edge[i].nex) {
204             int y = g[0].edge[i].v;
205             //printf("%d -> %d \n", x, y);
206             if(!col[x] && !col[y]) {
207                 g[1].add(x, y);
208                 //printf("g1 : add %d %d \n", x, y);
209             }
210             else if(!col[x]) {
211                 g[1].add(x, v[col[y]][0]);
212                 //printf("g1 : add %d %d \n", x, v[col[y]][0]);
213             }
214             else if(!col[y]) {
215                 g[1].add(v[col[x]][0], y);
216                 //printf("g1 : add %d %d \n", v[col[x]][0], y);
217             }
218             else if(col[x] != col[y]) {
219                 g[1].add(v[col[x]][0], v[col[y]][0]);
220                 //printf("g1 : add %d %d \n", v[col[x]][0], v[col[y]][0]);
221             }
222         }
223     }
224     
225     g[1].DP(v[1][0], 0);
226     
227     printf("%d\n", f[v[1][0]]);
228     return 0;
229     
230     
231     
232     g[1].pre1(v[1][0]);
233     g[1].lcapre();
234     
235     /*printf("out G1 : \n");
236     out(v[1][0], g[1]);
237     puts("");*/
238     
239     for(int i = 1; i <= m; i++) {
240         imp[++K] = v[i][0];
241     }
242     /*for(int i = 1; i <= n; i++) {
243         if(!col[i]) {
244             imp[++K] = i;
245         }
246     }*/
247     
248     std::sort(imp + 1, imp + K + 1, cmp);
249     int rt = g[1].build_t(g[2]);
250     
251     //printf("out G2 : \n");
252     //out(rt, g[2]);
253     
254     int ans = g[2].cal(rt);
255     printf("%d\n", ans);
256     return 0;
257 }
AC代码

相关文章: