参考:https://www.luogu.org/blog/Owencodeisking/post-xue-xi-bi-ji-cdq-fen-zhi-hu-zheng-ti-er-fen

前置技能:树状数组,线段树,分治、归并排序

CDQ分治:

据说是OI大佬陈丹琦发明的

1.三维偏序

思路:

第一维排序,第二维分治,第三维树状数组上查询

考虑分治时区间 [l, m] 对区间 [m+1, r] 的贡献,因为第一维已经排好序,所以区间 [l, m] 的第一维小于区间 [m+1, r]的第一维

然后对于区间 [m+1, r]中的某个元素x,将区间 [l, m] 的第二维小于x的元素的按第三维的权值加入树状数组,

最后区间 [l, m] 对区间 x 的贡献就是查询树状数组中小于x第三维的个数

可以边进行分治边进行归并排序,树状数组要及时清空

算法笔记--CDQ分治 && 整体二分

通过画图我们可以发现,对于每个位置,我们在分治时,它之前的位置对它的贡献都计算过了,所以这种方法是正确的。

因为递归的层数是log(n)层,再加上树状数组,所以时间复杂度是O(n*log(n)^2)

P3810 【模板】三维偏序(陌上花开) 

代码:

#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize(4)
#include<bits/stdc++.h>
using namespace std;
#define y1 y11
#define fi first
#define se second
#define pi acos(-1.0)
#define LL long long
#define ls rt<<1, l, m
#define rs rt<<1|1, m+1, r
//#define mp make_pair
#define pb push_back
#define ULL unsigned LL
#define pll pair<LL, LL>
#define pli pair<LL, int>
#define pii pair<int, int>
#define piii pair<pii, int>
#define pdi pair<double, int>
#define pdd pair<double, double>
#define mem(a, b) memset(a, b, sizeof(a))
#define debug(x) cerr << #x << " = " << x << "\n";
#define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
//head

const int N = 1e5 + 5, M = 2e5 + 5;
struct Node {
    int x, y, z;
    int ans, cnt;
    bool operator < (const Node & rhs) const {
        if(x == rhs.x) {
            if(y == rhs.y) return z < rhs.z;
            else return y < rhs.y;
        }
        else return x < rhs.x;
    }
}a[N], tmp[N];
int bit[M], res[N], n, k, cnt = 0;
void add(int x, int a) {
    while(x <= k) bit[x] += a, x += x&-x;
}
int sum(int x) {
    int res = 0;
    while(x) res += bit[x], x -= x&-x;
    return res;
}
void cdq(int l, int r) {
    if(l == r) {
        a[l].ans += a[l].cnt-1;
        return ;
    }
    int m = l+r >> 1;
    cdq(l, m);
    cdq(m+1, r);
    int p = l, q = m+1, tp = l;
    while(q <= r) {
        while(p <= m && a[p].y <= a[q].y) add(a[p].z, a[p].cnt), tmp[tp++] = a[p], ++p;
        a[q].ans += sum(a[q].z);
        tmp[tp++] = a[q];
        ++q;
    }
    for (int i = l; i < p; ++i) add(a[i].z, -a[i].cnt);
    while(p <= m) tmp[tp++] = a[p], ++p;
    for (int i = l; i <= r; ++i) a[i] = tmp[i];
}
int main() {
    scanf("%d %d", &n, &k);
    for (int i = 1; i <= n; ++i) scanf("%d %d %d", &a[i].x, &a[i].y, &a[i].z);
    sort(a+1, a+1+n);
    int now = 1;
    for (int i = 2; i <= n; ++i) {
        if(a[i].x == a[i-1].x && a[i].y == a[i-1].y && a[i].z == a[i-1].z) ++now;
        else {
            a[++cnt] = a[i-1];
            a[cnt].cnt = now;
            a[cnt].ans = 0;
            now = 1;
        }
    }
    a[++cnt] = a[n];
    a[cnt].cnt = now;
    a[cnt].ans = 0;
    cdq(1, cnt);
    for (int i = 1; i <= cnt; ++i) res[a[i].ans] += a[i].cnt;
    for (int i = 0; i < n; ++i) printf("%d\n", res[i]);
    return 0;
}

 例题1:CodeForces - 669E

思路:时间看成一个维度就转换成了三维偏序了

代码:

#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize(4)
#include<bits/stdc++.h>
using namespace std;
#define y1 y11
#define fi first
#define se second
#define pi acos(-1.0)
#define LL long long
#define ls rt<<1, l, m
#define rs rt<<1|1, m+1, r
//#define mp make_pair
#define pb push_back
#define ULL unsigned LL
#define pll pair<LL, LL>
#define pli pair<LL, int>
#define pii pair<int, int>
#define piii pair<pii, int>
#define pdi pair<double, int>
#define pdd pair<double, double>
#define mem(a, b) memset(a, b, sizeof(a))
#define debug(x) cerr << #x << " = " << x << "\n";
#define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
//head

const int N = 1e5 + 5;
struct node {
    int a, t, x, ans, id;
    bool operator < (const node & rhs) const {
        return id < rhs.id;
    } 
}a[N], tmp[N];
int n;
map<int, int> cnt;
void cdq(int l, int r) {
    if(l == r) return ;
    int m = l+r >> 1;
    cdq(l, m);
    cdq(m+1, r);
    int p = l, q = m+1, tp = l;
    while(q <= r) {
        while(p <= m && a[p].t <= a[q].t) {
            if(a[p].a == 1) cnt[a[p].x]++; 
            else if(a[p].a == 2) cnt[a[p].x]--;
            tmp[tp++] = a[p], ++p;
        }
        if(a[q].a == 3) a[q].ans += cnt[a[q].x];
        tmp[tp++] = a[q];
        ++q;
    }
    for (int i = l; i < p; ++i)  {
        if(a[i].a == 1) cnt[a[i].x]--; 
        else if(a[i].a == 2) cnt[a[i].x]++;
    }
    while(p <= m) tmp[tp++] = a[p], ++p;
    for (int i = l; i <= r; ++i) a[i] = tmp[i];
}
int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; ++i) scanf("%d %d %d", &a[i].a, &a[i].t, &a[i].x), a[i].ans = 0, a[i].id = i;
    cdq(1, n);
    sort(a+1, a+1+n);
    for (int i = 1; i <= n; ++i) if(a[i].a == 3) printf("%d\n", a[i].ans);
    return 0;
} 
View Code

相关文章: