Problem A awesome
给出一个序列$A_i$,任取序列中三个数组成三元组$(a_i , a_j , a_k)$。
输出本质不同的且$abc \equiv 1 (mod P)$且满足$a \leq b \leq c$的三元组$(a,b,c)$的组数。
对于$100\%$的数据满足$n \leq 2333 , P \in Prime$
Soltuion :
本题显然会卡常数,并且出了非常暧昧的数据范围。
设$n$不去重前的数据规模,而$m$是去重前的数据规模。
我们可以使用$O(n)$的暴力处理三元组中三个数都相同的情况。
我们可以使用$O(n^2)$暴力处理三元组中两个数的情况。
我们可以使用$O(n^2 log_2 n)$暴力处理三元组中所有数都不同的情况。
我们将数组中每个数以其模$P$的余数为键值插入到hash表中。
然后对于去重后的数组$O(m^2)$枚举两个不同的数,然后通过逆元从hash表中取出对应的数集即可。
可以通过一个$lower_bound$来计数。
时间复杂度为$O(n^2 log_2 n)$。
# pragma GCC optimize(3) # include <bits/stdc++.h> # define int long long # define hash Hash using namespace std; const int N=3e3+10; int n,p,a[N],s[N],inv[N]; vector<int>tmp; int Pow(int x,int n,int mo) { int ans = 1; while (n) { if (n&1) ans =ans * x %mo; x =x *x % mo; n>>=1; } return ans % mo; } struct Node {int key; vector<int>v;}; vector<Node> hash[1927]; void insert(int key,int val) { int to = key % 1926; for (int i=0;i<hash[to].size();i++) { if (key == hash[to][i].key) { hash[to][i].v.push_back(val); return ; } } Node tmp; tmp.key=key; tmp.v.push_back(val); hash[to].push_back(tmp); } vector<int>ttt; bool find(int key) { int to = key % 1926; for (int i=0;i<hash[to].size();i++) { if (key == hash[to][i].key) { ttt = hash[to][i].v; return 1; } } return 0; } signed main() { scanf("%lld%lld",&n,&p); for (int i=1;i<=n;i++) scanf("%lld",&a[i]); sort(a+1,a+1+n); for (int i=1;i<=n;i++) { for (int j=1;j<=n;j++) if (a[i] == a[j]) s[i]++; } for (int i=1;i<=n;i++) inv[i] = Pow(a[i],p-2,p); for (int i=1;i<=n;i++) tmp.push_back(a[i]); tmp.erase(unique(tmp.begin(),tmp.end()),tmp.end()); for (int i=0;i<tmp.size();i++) insert(tmp[i]%p,tmp[i]); int ans = 0; for (int i=1;i<=n;i++) { if (a[i] == a[i-1] && i!=1) continue; if (s[i]<3) continue; if (a[i] * a[i] % p * a[i] % p == 1) ans++; } for (int i=1;i<=n;i++) { if (a[i] == a[i-1] && i!=1) continue; if (s[i] < 2) continue; for (int j=0;j<tmp.size();j++) if (tmp[j] != a[i] && a[i] * a[i] % p * tmp[j] % p == 1) ans++; } for (int i=1;i<=n;i++) { if (a[i] == a[i-1] && i!=1) continue; for (int j=i+1;j<=n;j++) { if (a[j] == a[j-1] && j!=i+1) continue; if (a[i] >= a[j]) continue; if (a[i] % p == 0 || a[j] % p == 0) continue; int key = 1 * inv[i] * inv[j] % p; if (!find(key)) continue; ans += ttt.end()-upper_bound(ttt.begin(),ttt.end(),a[j]); } } printf("%lld\n",ans); return 0; }