可以在O(log x) 中计算组合数,忽略对任意大小的整数执行矩阵乘法所需的时间。
组合的数量可以表示为重复。让S(n) 成为通过从集合中添加数字来生成数字n 的方法的数量。复发是
S(n) = a_1*S(n-1) + a_2*S(n-2) + ... + a_15*S(n-15),
其中a_i 是i 在集合中出现的次数。此外,对于 nA 来表述(或更小是集合中的最大数更小)。那么,如果你有一个列向量V 包含
S(n-14) S(n-13) ... S(n-1) S(n),
那么矩阵乘法A*V的结果将是
S(n-13) S(n-12) ... S(n) S(n+1).
A矩阵定义如下:
0 1 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 1 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 1 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 1 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 1 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 1 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 1 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 1 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 1 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 1 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 1 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 1 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 1
a_15 a_14 a_13 a_12 a_11 a_10 a_9 a_8 a_7 a_6 a_5 a_4 a_3 a_2 a_1
其中a_i 定义如上。通过手动执行乘法,可以立即看到该矩阵与S(n_14) ... S(n) 向量的乘法有效的证明;向量中的最后一个元素将等于n+1 的递归右侧。通俗地说,矩阵中的元素将列向量中的元素向上移动一行,矩阵的最后一行计算最新的项。
为了计算任意项S(n)的递归就是计算A^n * V,其中V等于
S(-14) S(-13) ... S(-1) S(0) = 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1.
为了将运行时降低到O(log x),可以使用exponentiation by squaring 来计算A^n。
实际上,完全忽略列向量就足够了,A^n 的右下方元素包含所需的值S(n)。
如果上面的解释很难理解,我提供了一个 C 程序,它可以按照我上面描述的方式计算组合的数量。注意它会很快溢出一个 64 位整数。使用GMP,您将能够更进一步地使用高精度浮点类型,尽管您不会得到确切的答案。
不幸的是,我看不到一个快速的方法来获得 x=10^18 等数字的确切答案,因为答案可能比 10^x 大得多。
#include <stdio.h>
typedef unsigned long long ull;
/* highest number in set */
#define N 15
/* perform the matrix multiplication out=a*b */
void matrixmul(ull out[N][N],ull a[N][N],ull b[N][N]) {
ull temp[N][N];
int i,j,k;
for(i=0;i<N;i++) for(j=0;j<N;j++) temp[i][j]=0;
for(k=0;k<N;k++) for(i=0;i<N;i++) for(j=0;j<N;j++)
temp[i][j]+=a[i][k]*b[k][j];
for(i=0;i<N;i++) for(j=0;j<N;j++) out[i][j]=temp[i][j];
}
/* take the in matrix to the pow-th power, return to out */
void matrixpow(ull out[N][N],ull in[N][N],ull pow) {
ull sq[N][N],temp[N][N];
int i,j;
for(i=0;i<N;i++) for(j=0;j<N;j++) temp[i][j]=i==j;
for(i=0;i<N;i++) for(j=0;j<N;j++) sq[i][j]=in[i][j];
while(pow>0) {
if(pow&1) matrixmul(temp,temp,sq);
matrixmul(sq,sq,sq);
pow>>=1;
}
for(i=0;i<N;i++) for(j=0;j<N;j++) out[i][j]=temp[i][j];
}
void solve(ull n,int *a) {
ull m[N][N];
int i,j;
for(i=0;i<N;i++) for(j=0;j<N;j++) m[i][j]=0;
/* create matrix from a[] array above */
for(i=2;i<=N;i++) m[i-2][i-1]=1;
for(i=1;i<=N;i++) m[N-1][N-i]=a[i-1];
matrixpow(m,m,n);
printf("S(%llu): %llu\n",n,m[N-1][N-1]);
}
int main() {
int a[]={1,1,0,0,0,0,0,1,0,0,0,0,0,0,0};
int b[]={1,1,1,1,1,0,0,0,0,0,0,0,0,0,0};
solve(13,a);
solve(80,a);
solve(15,b);
solve(66,b);
return 0;
}