我等蒟蒻爆零之后,问LincHpin大爷:“此等神题可有甚么来头?”
LincHpin:“此三题皆为当年ZXR前辈所留。”
固名之,ZXR专场,233~~~
这个题在BZOJ和HDU上都有身影,一样不一样的吧,反正意思差不多,想法也很相近。
首先就是发现mex函数的一个性质——当左端点固定时,函数值随右端点单调,即$mex(i,j) \leq mex(i,j+1)$。
然后,我们这么做:先$O(N)$求出以1位左端点,右端点分别为$i=1..n$的mex函数值,然后不断将左端点向右移动。
当左端点从$i$移动到$i+1$时,会使得$num_{i}$从维护序列中消失,考虑对那些mex函数值的影响。
首先,$mex_{1}$到$mex_{i}$的函数值已经没用了,可以忽略。
如果我们已经知道$num_{i}$下一次出现的位置是$next_{i}$,那么从$i$到$next_{i}-1$区间内的所有mex值都应当对$num_{i}$取min。
又因为mex函数值一直保持单调性,所有可以二分出来需要修改的区间,然后就变成了区间修改,区间求和,那就是线段树了。
1 #include <cstdio> 2 #include <cstring> 3 4 inline char nextChar(void) 5 { 6 static const int siz = 1 << 20; 7 8 static char buf[siz]; 9 static char *hd = buf + siz; 10 static char *tl = buf + siz; 11 12 if (hd == tl) 13 fread(hd = buf, 1, siz, stdin); 14 15 return *hd++; 16 } 17 18 inline int nextInt(void) 19 { 20 register int ret = 0; 21 register bool neg = false; 22 register char bit = nextChar(); 23 24 for (; bit < 48; bit = nextChar()) 25 if (bit == '-')neg ^= true; 26 27 for (; bit > 47; bit = nextChar()) 28 ret = ret * 10 + bit - '0'; 29 30 return neg ? -ret : ret; 31 } 32 33 typedef long long lnt; 34 35 const int siz = 500005; 36 37 int n, num[siz]; 38 39 int nxt[siz]; 40 int lst[siz]; 41 int mex[siz]; 42 43 bool vis[siz]; 44 45 inline void prework(void) 46 { 47 for (int i = 0; i <= n; ++i) 48 lst[i] = n + 1; 49 50 for (int i = n; i >= 1; --i) 51 { 52 nxt[i] = lst[num[i]]; 53 lst[num[i]] = i; 54 } 55 56 memset(vis, false, sizeof(vis)); 57 58 int ans = 0; 59 60 for (int i = 1; i <= n; ++i) 61 { 62 vis[num[i]] = true; 63 64 while (vis[ans]) 65 ++ans; 66 67 mex[i] = ans; 68 } 69 } 70 71 lnt sum[siz << 2]; 72 lnt tag[siz << 2]; 73 int son[siz << 2]; 74 75 void build(int t, int l, int r) 76 { 77 tag[t] = -1; 78 79 if (l == r) 80 sum[t] = son[t] = mex[l]; 81 else 82 { 83 int mid = (l + r) >> 1; 84 85 build(t << 1, l, mid); 86 build(t << 1 | 1, mid + 1, r); 87 88 son[t] = son[t << 1 | 1]; 89 sum[t] = sum[t << 1] + sum[t << 1 | 1]; 90 } 91 } 92 93 inline void addtag(int t, lnt v, lnt k) 94 { 95 son[t] = v; 96 tag[t] = v; 97 sum[t] = v * k; 98 } 99 100 inline void pushdown(int t, int l, int r) 101 { 102 int mid = (l + r) >> 1; 103 addtag(t << 1, tag[t], mid - l + 1); 104 addtag(t << 1 | 1, tag[t], r - mid); 105 tag[t] = -1; 106 } 107 108 lnt query(int t, int l, int r, int x, int y) 109 { 110 if (l == x && r == y) 111 return sum[t]; 112 else 113 { 114 if (~tag[t]) 115 pushdown(t, l, r); 116 117 int mid = (l + r) >> 1; 118 119 if (y <= mid) 120 return query(t << 1, l, mid, x, y); 121 else if (x > mid) 122 return query(t << 1 | 1, mid + 1, r, x, y); 123 else return 124 query(t << 1, l, mid, x, mid) 125 + query(t << 1 | 1, mid + 1, r, mid + 1, y); 126 } 127 } 128 129 int query(int t, int l, int r, int v) 130 { 131 if (son[t] <= v) 132 return n + 1; 133 134 if (l == r) 135 return l; 136 137 if (~tag[t]) 138 pushdown(t, l, r); 139 140 int mid = (l + r) >> 1; 141 142 int s = son[t << 1]; 143 144 if (s > v) 145 return query(t << 1, l, mid, v); 146 else 147 return query(t << 1 | 1, mid + 1, r, v); 148 } 149 150 void change(int t, int l, int r, int x, int y, lnt v) 151 { 152 if (l == x && r == y) 153 addtag(t, v, r - l + 1); 154 else 155 { 156 if (~tag[t]) 157 pushdown(t, l, r); 158 159 int mid = (l + r) >> 1; 160 161 if (y <= mid) 162 change(t << 1, l, mid, x, y, v); 163 else if (x > mid) 164 change(t << 1 | 1, mid + 1, r, x, y, v); 165 else 166 { 167 change(t << 1, l, mid, x, mid, v); 168 change(t << 1 | 1, mid + 1, r, mid + 1, y, v); 169 } 170 171 son[t] = son[t << 1 | 1]; 172 sum[t] = sum[t << 1] + sum[t << 1 | 1]; 173 } 174 } 175 176 inline void solve(void) 177 { 178 lnt ans = 0LL; 179 180 prework(); 181 182 build(1, 1, n); 183 184 for (int i = 1; i <= n; ++i) 185 { 186 ans += query(1, 1, n, i, n); 187 188 int pos = query(1, 1, n, num[i]); // first > num[i] 189 190 if (pos <= nxt[i] - 1) 191 change(1, 1, n, pos, nxt[i] - 1, num[i]); 192 } 193 194 printf("%lld\n", ans); 195 } 196 197 signed main(void) 198 { 199 freopen("mex.in", "r", stdin); 200 freopen("mex.out", "w", stdout); 201 202 n = nextInt(); 203 204 for (int i = 1; i <= n; ++i) 205 num[i] = nextInt(); 206 207 for (int i = 1; i <= n; ++i) 208 if (num[i] > n) 209 num[i] = n; 210 211 solve(); 212 213 fclose(stdin); 214 fclose(stdout); 215 216 return 0; 217 }