这里学习一下DP的正确姿势。
也为了ZJOI2019去水一下做一些准备
题解就随便写写啦.
后续还是会有专题练习和综合练习的.
给出$n \times m$矩阵每次在每一行取n个数,一共取m次,
第i次取数的权值是$2^i$,给出一个取数的顺序,最大化取完所有数的贡献和。
输出贡献和。
对于100%的数据$1\leq n,m \leq 80,0<a_{i,j} \leq 10^3 $
需要使用高精度,考虑一个DP,$f[i][j][k]$表示第i行,共取j次,其中k次在前面,显然(j-k)在后面
第j次取数在前面取的那么$f[i][j][k]$是由$f[i][j-1][k-1]$转移而来,得dp方程$f[i][j][k]=f[i][j-1][k-1]+2^j \times a[k]$ (取的数在从左往右数第$k$个)
第j次取数在后面取的那么$f[i][j][k]$是由$f[i][j-1][k]$转移而来,得dp方程$f[i][j-1][k]+2^j \times a[m-(j-k)+1]$ (取的数在从左向右数第$m-(j-k)+1$个)
然后整理可得DP方程$f[i][j][k]=max{f[i][j-1][k-1]+2^j \times a[k] , f[i][j-1][k]+2^j \times a[m-(j-k)+1] }$
考虑边界条件对于$f[i][j][k]$是有意义的当且仅当$0 \leq j \leq k$,然后使用高精度计算(预处理2的幂的Lint数)
复杂度$O(km^3)$,其中k是高精度数的位数可视为常数
code:
# include <bits/stdc++.h> # define fp(i,s,t) for(int i=s;i<=t;i++) # define int long long using namespace std; const int N=85; const int L=50; char s[L]; inline int read() { int X=0,w=0; char c=0; while(c<'0'||c>'9') {w|=c=='-';c=getchar();} while(c>='0'&&c<='9') X=(X<<3)+(X<<1)+(c^48),c=getchar(); return w?-X:X; } struct Lint{ int len,num[L]; Lint () { len=0; memset(num,0,sizeof(num));} Lint change(int x) { Lint a; if (x==0) { a.num[1]=0;a.len=1; return a;} a.len=0; while (x>0) a.num[++a.len]=x%10,x/=10; return a; } void print(Lint x) { for (int i=x.len;i>=1;i--) putchar(x.num[i]+'0'); putchar('\n'); } Lint operator + (const Lint &t) const { Lint ans; memset(ans.num,0,sizeof(ans.num)); ans.len=max(len,t.len); for (int i=1;i<=ans.len;i++) { ans.num[i]+=num[i]+t.num[i]; ans.num[i+1]+=ans.num[i]/10; ans.num[i]%=10; } if (ans.num[ans.len+1]>0) ans.len++; return ans; } Lint operator * (const Lint &t) const { Lint ans; for (int i=1;i<=len;i++) for (int j=1;j<=t.len;j++) ans.num[i+j-1]+=num[i]*t.num[j]; for (int i=1;i<=len+t.len;i++) { ans.num[i+1]+=ans.num[i]/10; ans.num[i]%=10; } if (ans.num[len+t.len]>0) ans.len=len+t.len; else ans.len=len+t.len-1; return ans; } bool operator < (const Lint &t) const { if (len>t.len) return false; else if (len<t.len) return true; for (int i=len;i>=1;i--) if (num[i]<t.num[i]) return true; else if (num[i]>t.num[i]) return false; return false; } }rd; Lint Max(Lint x,Lint y){if (x<y) return y; else return x;} int n,m; Lint f[N][N],a[N][N],ans; Lint pw[N]; signed main() { scanf("%lld%lld",&n,&m); fp(i,1,n) fp(j,1,m) a[i][j]=rd.change(read()); pw[0]=rd.change(1); Lint num_2=rd.change(2); fp(i,1,m) pw[i]=pw[i-1]*num_2; fp(i,1,n) { Lint ret=rd.change(0); memset(f,0,sizeof(f)); fp(j,1,m) fp(k,0,j) { if (k!=0) f[j][k]=f[j-1][k-1]+pw[j]*a[i][k]; if (j-1>=k-1) f[j][k]=max(f[j][k],f[j-1][k]+pw[j]*a[i][m-j+k+1]); } fp(k,0,m) ret=Max(ret,f[m][k]); ans=ans+ret; } rd.print(ans); return 0; }