Problem A 树状数组

  给出下列$C++$代码:

  HGOI 20191107 题解

  设区间加操作$modify(l,r)$为调用两次$update(r,1)$和$update(l-1,-1)$

  设$f(l,r)$表示在初始$cnt[i]$全部是$0$的情况下进行$modify(l,r)$操作后,cnt数组中含有非$0$元素的个数。

  给出$T$组询问,输出$\sum\limits_{i=1}^{n} \sum\limits_{j=i}^{n} f(i,j)$的值。

  对于 $100\%$的数据满足,$T\leq 10^4 , n \leq 10^{18}$

Solution :

   上来就做如此恶心的数位$DP$,不过也好,复习了一下数位$DP$。

   观察到$f(l,r)$答案的构成是$l$的二进制个数+$r$的二进制个数 - $l,r$二进制表示形成字符串的lcp的$1$的个数。

   对$n$二进制拆分,从高位到低位依次考虑,设$f[i][op1][op2][j]$当前考虑到第$i$位,当前$r$对$n$是否有限制(op1),当前$l$是否对$r$有限制(op2)。

  而$j$这一维状态,分两次$dp$考虑。

  第一次,我们要求出所有数对$(0 \leq l< r \leq n)$中两个数的二进制$1$的个数的和。

    所以,容易的设$j$为更高位$1$的数目,而整个dp的值表示方案数,总和就是方案数$\times j$的和

  第二次,我们要求出所有数对$(0 \leq l< r \leq n)$中两个数二进制串的$lcp$的$1$的个数

    所以,容易的设$j$为更高位公共$1$的个数,同时整个$dp$的值表示方案数,总和就是方案数$\times j$的和

  按照普通的数位$dp$转移即可,需要注意一些细节,这里就不再赘述一些沙雕错误了,代码用了循环展开,可读性极差。

  复杂度是$O(T {log_2}^2 n)$

# pragma GCC optimize(3,"Ofast") 
# include<bits/stdc++.h>
# define Rint register int  
using namespace std;
const int mo=1e9+7;
long long n;
int a[105];
int f[64][2][2][2*64];
inline void pls(int &a, int b) {
    a = (a + b >= mo ? a + b - mo : a + b);
}
signed main(){
    int T; scanf("%d",&T);
    while (T--) {
        scanf("%lld",&n);a[0]=0; while (n) { a[++a[0]]=n&1; n>>=1;}
        f[a[0]][0][1][0]=1, f[a[0]][1][0][1]=1, f[a[0]][1][1][2]=1;
        int HJCAK = a[0] << 1;
        for (Rint i=a[0];i>=2;i--)
            for (Rint op1=0;op1<=1;op1++){
                Rint op2 = 0;
                for (Rint j=0;j<=HJCAK;j++) {
                    pls(f[i-1][(op1)&&!a[i-1]][op2][j],f[i][op1][op2][j]);
                    if (op1 && a[i-1]) pls(f[i-1][1][op2][j+2],f[i][op1][op2][j]);
                    if (!op1) pls(f[i-1][0][op2][j+2],f[i][op1][op2][j]);
                    if (!op2) pls(f[i-1][op1&&!a[i-1]][0][j+1],f[i][op1][op2][j]);
                    if (op1 && a[i-1]) pls(f[i-1][1][0][j+1],f[i][op1][op2][j]);
                    if (!op1) pls(f[i-1][0][0][j+1],f[i][op1][op2][j]);
                }
                op2 = 1;
                for (Rint j=0;j<=HJCAK;j++) {
                    pls(f[i-1][(op1)&&!a[i-1]][op2][j],f[i][op1][op2][j]);
                    if (op1 && a[i-1]) pls(f[i-1][1][op2][j+2],f[i][op1][op2][j]);
                    if (!op1) pls(f[i-1][0][op2][j+2],f[i][op1][op2][j]);
                    if (!op2) pls(f[i-1][op1&&!a[i-1]][0][j+1],f[i][op1][op2][j]);
                    if (op1 && a[i-1]) pls(f[i-1][1][0][j+1],f[i][op1][op2][j]);
                    if (!op1) pls(f[i-1][0][0][j+1],f[i][op1][op2][j]);
                }
            }
        int res1=0;
        for (Rint op1=0;op1<=1;op1++)
            for (Rint j=1;j<=HJCAK;j++)
                pls(res1,1ll*f[1][op1][0][j]*j%mo);
        memset(f,0,sizeof(f));
        f[a[0]][0][1][0]=1, f[a[0]][1][0][0]=1, f[a[0]][1][1][1]=1;
        for (Rint i=a[0];i>=2;i--) 
            for (Rint op1=0;op1<=1;op1++){
                Rint op2 = 0;
                for (Rint j=0;j<=a[0];j++) {
                    pls(f[i-1][op1&&!a[i-1]][op2][j],f[i][op1][op2][j]);
                    if (op1 && a[i-1]) pls(f[i-1][1][op2][j+op2],f[i][op1][op2][j]);
                    if (!op1) pls(f[i-1][0][op2][j+op2],f[i][op1][op2][j]);
                    if (!op2) pls(f[i-1][op1&&!a[i-1]][0][j],f[i][op1][op2][j]);
                    if (op1 && a[i-1]) pls(f[i-1][1][0][j],f[i][op1][op2][j]);
                    if (!op1) pls(f[i-1][0][0][j],f[i][op1][op2][j]);
                }
                op2 = 1;
                for (Rint j=0;j<=a[0];j++) {
                    pls(f[i-1][op1&&!a[i-1]][op2][j],f[i][op1][op2][j]);
                    if (op1 && a[i-1]) pls(f[i-1][1][op2][j+op2],f[i][op1][op2][j]);
                    if (!op1) pls(f[i-1][0][op2][j+op2],f[i][op1][op2][j]);
                    if (!op2) pls(f[i-1][op1&&!a[i-1]][0][j],f[i][op1][op2][j]);
                    if (op1 && a[i-1]) pls(f[i-1][1][0][j],f[i][op1][op2][j]);
                    if (!op1) pls(f[i-1][0][0][j],f[i][op1][op2][j]);
                }
            }
        int res2=0;
        for (Rint op1=0;op1<=1;op1++)
            for (Rint j=1;j<=a[0];j++)
                pls(res2,1ll*f[1][op1][0][j]*j%mo);
        res2 <<= 1; res2 >= mo && (res2 -= mo);
        printf("%d\n",(res1-res2+mo)%mo);
        if(T) memset(f,0,sizeof(f));
    }
    return 0;
}
bit.cpp

相关文章: