题目
题目来源:2926. 平衡子序列的最大和
本文是对 力扣第 370 场周赛 没做出来的最后一题的题解的分析。题解参考了灵茶山艾府的文章。
思路
DP
重新读题:一个平衡子序列中连续的两个元素 nums[i]
和 nums[j]
需要满足 $i < j$ 且 $nums[j] - nums[i] \ge j -i$
对上述不等式移项,得到 $nums[j] - j \ge nums[i] - i$
令 $b[i] = nums[i] - i$,则平衡子序列中每个元素需要满足的条件只和自身下标有关。
因此,这道题类似于最长递增子序列,用 DP 解决:
DP 定义:dp[j]
表示以 nums[j]
结尾的 平衡 子序列里面的 最大元素和 。
递推式:$dp[j] = nums[j] + \max\set{dp[i] | (i < j) ∧ (b[j] \ge b[i])}$
代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
| import java.util.Arrays; import java.util.stream.IntStream;
class Solution { public long maxBalancedSubsequenceSum(int[] nums) { int n = nums.length; int[] b = IntStream.range(0, n).map(i -> nums[i] - i).toArray(); long[] dp = new long[n]; for (int j = 0; j < n; j++) { dp[j] = nums[j]; long max = 0; for (int i = 0; i < j; i++) { if (b[j] >= b[i]) max = Math.max(max, dp[i]); } dp[j] += max; } return Arrays.stream(dp).max().getAsLong(); } }
|
显然,两层循环的时间复杂度是 $O(n^2)$,肯定会超时。我们应该如何优化呢?
树状数组 BIT
根据上述递推式,我们的代码逻辑应该是这样的:
- 对于每一个 $j$
- 遍历区间 $[0, j)$,找到其中使得 $b[j] \ge b[i]$ 并且 $dp[i]$ 最大的下标 $i$ —— 区间查询最大值
- 令 $dp[j] = nums[j] + dp[i]$ —— 单点更新
我们需要让 『区间查询』+ 『单点更新』的时间复杂度严格小于 $O(n)$,这样才不会超时。哪一种现有的数据结构能做到这一点?—— 树状数组(Binary Index Tree)。
我本来以为 BIT 只适合区间查询求和(对于差分数组的优化),但这道题告诉了我们,区间查询的对象可以拓展到最小值和最大值。
BIT 中存放的是什么?<b[j], dp[j]>
的键值对,即 b[j]
是下标,dp[j]
是数组元素。我们要查询所有下标 $\le b[j]$ 的元素的最大值,即前缀区间的最大值。
代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
| import java.util.Arrays; import java.util.stream.IntStream;
public class Solution { static class BIT { long[] arr; int offset;
BIT(int min, int max) { offset = -min; offset++; arr = new long[max - min + 2]; }
int lsb(int i) { return i & -i; }
void setMax(int i, long val) { i += offset; for (; i < arr.length; i += lsb(i)) { arr[i] = Math.max(arr[i], val); } }
long getMax(int i) { i += offset; long ans = Long.MIN_VALUE; for (; i > 0; i -= lsb(i)) { ans = Math.max(ans, arr[i]); } return ans; }
}
public long maxBalancedSubsequenceSum(int[] nums) { int n = nums.length; int[] b = IntStream.range(0, n).map(i -> nums[i] - i).toArray(); int max = Arrays.stream(b).max().getAsInt(); int min = Arrays.stream(b).min().getAsInt(); BIT tree = new BIT(min, max);
long ans = Long.MIN_VALUE; for (int j = 0; j < n; j++) { long dp_i = Math.max(tree.getMax(b[j]), 0); long dp_j = dp_i + nums[j]; tree.setMax(b[j], dp_j);
ans = Math.max(ans, dp_j); } return ans; } }
|
提交该代码:
注意题目的数据范围:$n \le 10^5$ ,$nums[i] \in[-10^9,10^9]$
因此 b[i]
的取值范围约等于 32-bit 整型的取值范围,即 $2^{32}$,这显然会在创建 BIT#arr
时导致 OOM。如何优化空间呢?
离散化
参考 OI-wiki 的描述:
离散化是一种数据处理的技巧,本质上可以看成是一种哈希,其保证数据在哈希以后仍然保持原来的全/偏序关系。
通俗地讲就是当有些数据因为本身很大或者类型不支持,自身无法作为数组的下标来方便地处理,而影响最终结果的只有元素之间的相对大小关系时,我们可以将原来的数据按照排名来处理问题,即离散化。
结合本题来理解:
- 我们要查询所有下标 $\le b[j]$ 的元素的最大值 —— $b[j]$ 有 $2^{32}$ 个可能取值。
- 而数组
b
只有 n
个元素。
- 假设
b[j]
在排序后是第 $k$ 小的元素,那么我们只需要查询前 $k - 1$ 个比 b[j]
小的元素中的最大值。
- 因此,
BIT#arr
的长度也是 n
,且 $n \ll 2^{32}$。
- 假设我们知道
b[j]
在排序后的下标 k
,那么『查询所有下标 $\le b[j]$ 的元素的最大值』就等价于 getMax(k)
。
从结果上来看,『离散化』应该叫做『归一化』更合适?🤔
离散化的过程也是典型的模板,背出来就行了:
- 创建原数组的副本。
- 将副本中的值从小到大排序。
- 将排序好的副本去重。
- 查找原数组的每一个元素在副本中的位置,位置即为排名,将其作为离散化后的值。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
|
private int[] discretize(int[] b) { int[] arr = Arrays.stream(b).sorted().distinct().toArray(); int n = b.length; int[] ans = new int[n];
for (int i = 0; i < n; i++) { int order = Arrays.binarySearch(arr, b[i]); ans[i] = order + 1; } return ans; }
|
如果不去重,C++ 使用 std::lower_bound
代替 std::binary_search
。但是 Java 中没有这样的函数,因此一定要去重。
代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73
| import java.util.Arrays; import java.util.stream.IntStream;
public class Solution { static class BIT { long[] arr;
BIT(int n) { arr = new long[n + 1]; }
int lsb(int i) { return i & -i; }
void setMax(int i, long val) { for (; i < arr.length; i += lsb(i)) { arr[i] = Math.max(arr[i], val); } }
long getMax(int i) { long ans = Long.MIN_VALUE; for (; i > 0; i -= lsb(i)) { ans = Math.max(ans, arr[i]); } return ans; }
}
private int[] discretize(int[] b) { int[] arr = Arrays.stream(b).sorted().distinct().toArray(); int n = b.length; int[] ans = new int[n];
for (int i = 0; i < n; i++) { int order = Arrays.binarySearch(arr, b[i]); ans[i] = order + 1; } return ans; }
public long maxBalancedSubsequenceSum(int[] nums) { int n = nums.length; int[] b = IntStream.range(0, n).map(i -> nums[i] - i).toArray(); int[] arr = discretize(b); BIT tree = new BIT(n + 1);
long ans = Long.MIN_VALUE; for (int j = 0; j < n; j++) { int k = arr[j]; long dp_i = Math.max(tree.getMax(k), 0); long dp_j = dp_i + nums[j]; tree.setMax(k, dp_j);
ans = Math.max(ans, dp_j); } return ans; } }
|
参考文章