斜率优化是单调队列优化的推广
用单调队列维护递增的斜率
参考:https://www.cnblogs.com/ka200812/archive/2012/08/03/2621345.html
以例1举例说明:
转移方程为:dp[i] = min(dp[j] + (sum[i] - sum[j])^2 + C)
假设k < j < i, 如果从j转移过来比从k转移过来更优
那么 dp[j] + (sum[i] - sum[j])^2 + C < dp[k] + (sum[i] - sum[k])^2 + C
dp[j] - dp[k] < (sum[i] - sum[k])^2 - (sum[i] - sum[j])^2
dp[j] - dp[k] < -2*sum[i]*sum[k] + sum[k]*sum[k] + 2*sum[i]*sum[j] - sum[j]*sum[j]
dp[j] - dp[k] + sum[j]*sum[j] - sum[k]*sum[k] < 2*sum[i]*(sum[j] - sum[k])
(dp[j] - dp[k] + sum[j]*sum[j] - sum[k]*sum[k]) / (sum[j] - sum[k]) < 2*sum[i]
我们观察不等式左边, 它是个斜率的形式, 自变量x为sum, 函数f(x)为dp + sum*sum
我们记这个斜率为g[j, k] = (dp[j] - dp[k] + sum[j]*sum[j] - sum[k]*sum[k]) / (sum[j] - sum[k])
说明1.如果g[j, k] < 2*sum[i] 表示对于dp[i], 从j转移过来比k更优, 反之k更优
说明2.下面我们来考虑着怎么从解集去掉多余的元素, 可以证明可能存在某些元素,无论怎样都不会是最优的,可以去掉这些多余的元素
假设k < j < i
结论:如果g[i, j] < g[j, k], 那么j可以去掉
证明:对于某个i, 如果g[i, j] < 2*sum[i], 那么i比j更优, 结论成立;
如果g[i, j] >= 2*sum[i], 那么g[j, k] > g[i, j] >= 2*sum[i], 那么k比j更优,结论成立.
证毕.
所以如果把所有g[i, j] < g[j, k]的情况中(后面斜率比前面斜率小的情况)的j都去掉, 那么我们就得到相邻两个元素的斜率递增的状况
如下图
下面来说明怎么维护这个解集:
用双端队列维护这个解集, 每次从后面加入元素时, 按照说明2的方式去掉多余元素,使的相邻元素之间构成的斜率保持单调
每次从前面找答案, 由于斜率单调递增, 所以最后一个小于2*sum[i]就是最优的解, 因为这个位置之前的g[i, j]都小于2*sum,
表示后面的比前面更优, 之后的g[i, j] 都大于2*sum, 表示前面的比后面更优, 所以这个点是极值点
又因为sum[i]也具有单调性, 所以下一个极值点的位置肯定大于等于当前极值点, 所以当前极值点之前的都可以从双端队列中移出
ps:所有说明中, k < j < i
例题1:HDU - 3507
思路:维护递增斜率g[i, j] = (dp[i] - dp[j] + sum[i]*sum[i] - sum[j]*sum[j]) / (sum[i] - sum[j])
代码:
#pragma GCC optimize(2) #pragma GCC optimize(3) #pragma GCC optimize(4) #include<bits/stdc++.h> using namespace std; #define y1 y11 #define fi first #define se second #define pi acos(-1.0) #define LL long long //#define mp make_pair #define pb emplace_back #define ls rt<<1, l, m #define rs rt<<1|1, m+1, r #define ULL unsigned LL #define pll pair<LL, LL> #define pli pair<LL, int> #define pii pair<int, int> #define piii pair<pii, int> #define pdd pair<double, double> #define mem(a, b) memset(a, b, sizeof(a)) #define debug(x) cerr << #x << " = " << x << "\n"; #define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0); //head const int N = 5e5 + 10; int a[N], n, m; LL sum[N], dp[N]; bool g(int k, int j, LL C) { return (dp[j]-dp[k]+sum[j]*sum[j]-sum[k]*sum[k]) <= C*(sum[j]-sum[k]); } bool gg(int k, int j, int i) { return (dp[i]-dp[j]+sum[i]*sum[i]-sum[j]*sum[j])*(sum[j]-sum[k]) <= (dp[j]-dp[k]+sum[j]*sum[j]-sum[k]*sum[k])*(sum[i]-sum[j]); } deque<int> q; int main() { while(~scanf("%d %d", &n, &m)) { for (int i = 1; i <= n; ++i) scanf("%d", &a[i]), sum[i] = sum[i-1]+a[i]; while(!q.empty()) q.pop_back(); q.push_back(0); for (int i = 1; i <= n; ++i) { while(q.size() >= 2) { int a = q.front(); q.pop_front(); int b = q.front(); if(g(a, b, 2*sum[i])) ; else { q.push_front(a); break; } } int j = q.front(); dp[i] = dp[j] + (sum[i]-sum[j])*(sum[i]-sum[j])+m; while(q.size() >= 2) { int b = q.back(); q.pop_back(); int a = q.back(); if(gg(a, b, i)) ; else { q.push_back(b); break; } } q.push_back(i); } printf("%lld\n", dp[n]); } return 0; }