判断树的同构,采用树hash的方式。
树hash定义在有根树上。判断无根树同构的时候,可以比较重心为根的hash值或者比较每个点为根的hash值。
h[x]表示x为根的子树的hash,g[x]表示x为根时全树的hash。
我采用的方法是
h[x] = 1 + ∑h[y] * p[siz[y]]
于是g[x] = g[fa] - h[x] * p[siz[x]] + h[x]
例题1: BJOI2015 树的同构
判断无根树同构,我是比较了每个点为根时的hash值。
1 #include <bits/stdc++.h> 2 3 typedef long long LL; 4 const int N = 60, MO = 998244353; 5 6 struct Edge { 7 int nex, v; 8 }edge[N << 1]; int tp; 9 10 int e[N], n, m, turn, fr[N], p[1000010], top, siz[N], h[N], g[N]; 11 std::vector<int> v[N]; 12 bool vis[1000010]; 13 14 inline void add(int x, int y) { 15 tp++; 16 edge[tp].v = y; 17 edge[tp].nex = e[x]; 18 e[x] = tp; 19 return; 20 } 21 22 inline bool equal(int a, int b) { 23 int len = v[a].size(); 24 if(len != v[b].size()) return false; 25 for(int i = 0; i < len; i++) { 26 if(v[a][i] != v[b][i]) return false; 27 } 28 return true; 29 } 30 31 inline void getp(int n) { 32 for(int i = 2; i <= n; i++) { 33 if(!vis[i]) { 34 p[++top] = i; 35 } 36 for(int j = 1; j <= top && i * p[j] <= n; j++) { 37 vis[i * p[j]] = 1; 38 if(i % p[j] == 0) { 39 break; 40 } 41 } 42 } 43 return; 44 } 45 46 void DFS_1(int x, int f) { 47 siz[x] = 1; 48 h[x] = 1; 49 for(int i = e[x]; i; i = edge[i].nex) { 50 int y = edge[i].v; 51 if(y == f) continue; 52 DFS_1(y, x); 53 h[x] = (h[x] + 1ll * h[y] * p[siz[y]] % MO) % MO; 54 siz[x] += siz[y]; 55 } 56 return; 57 } 58 59 void DFS_2(int x, int f, int V) { 60 g[x] = (h[x] + 1ll * V * p[n - siz[x]] % MO) % MO; 61 v[turn].push_back(g[x]); 62 V = (1ll * V * p[n - siz[x]] % MO + 1) % MO; 63 for(int i = e[x]; i; i = edge[i].nex) { 64 int y = edge[i].v; 65 if(y == f) { 66 continue; 67 } 68 DFS_2(y, x, ((LL)V + h[x] - 1 - 1ll * h[y] * p[siz[y]] % MO + MO) % MO); 69 } 70 return; 71 } 72 73 int main() { 74 getp(1000009); 75 scanf("%d", &m); 76 for(turn = 1; turn <= m; turn++) { 77 scanf("%d", &n); 78 tp = 0; 79 memset(e + 1, 0, n * sizeof(int)); 80 for(int i = 1, x; i <= n; i++) { 81 scanf("%d", &x); 82 if(x) { 83 add(x, i); 84 add(i, x); 85 } 86 } 87 DFS_1(1, 0); 88 DFS_2(1, 0, 0); 89 std::sort(v[turn].begin(), v[turn].end()); 90 /*for(int i = 0; i < n; i++) { 91 printf("%d ", v[turn][i]); 92 } 93 puts("\n");*/ 94 } 95 96 for(int i = 1; i <= m; i++) { 97 fr[i] = i; 98 } 99 for(int i = 2; i <= m; i++) { 100 for(int j = 1; j < i; j++) { 101 if(equal(i, j)) { 102 fr[i] = fr[j]; 103 break; 104 } 105 } 106 } 107 for(int i = 1; i <= m; i++) { 108 printf("%d\n", fr[i]); 109 } 110 return 0; 111 }