算法题005:平衡子序列的最大和

题目

题目来源:2926. 平衡子序列的最大和

本文是对 力扣第 370 场周赛 没做出来的最后一题的题解的分析。题解参考了灵茶山艾府的文章

image-20231105162645031

image-20231105162654150

思路

DP

重新读题:一个平衡子序列中连续的两个元素 nums[i]nums[j] 需要满足 $i < j$ 且 $nums[j] - nums[i] \ge j -i$

对上述不等式移项,得到 $nums[j] - j \ge nums[i] - i$

令 $b[i] = nums[i] - i$,则平衡子序列中每个元素需要满足的条件只和自身下标有关。

因此,这道题类似于最长递增子序列,用 DP 解决:

DP 定义:dp[j] 表示以 nums[j] 结尾平衡 子序列里面的 最大元素和

递推式:$dp[j] = nums[j] + \max\set{dp[i] | (i < j) ∧ (b[j] \ge b[i])}$

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import java.util.Arrays;
import java.util.stream.IntStream;

class Solution {
public long maxBalancedSubsequenceSum(int[] nums) {
int n = nums.length;
// b[i] = nums[i] - i
int[] b = IntStream.range(0, n).map(i -> nums[i] - i).toArray();
// dp[i]: 以 nums[i] 结尾的子序列的最大元素和
long[] dp = new long[n];
for (int j = 0; j < n; j++) {
dp[j] = nums[j];
long max = 0;
for (int i = 0; i < j; i++) {
if (b[j] >= b[i])
max = Math.max(max, dp[i]);
}
dp[j] += max;
}
return Arrays.stream(dp).max().getAsLong();
}
}

显然,两层循环的时间复杂度是 $O(n^2)$,肯定会超时。我们应该如何优化呢?

树状数组 BIT

根据上述递推式,我们的代码逻辑应该是这样的:

  • 对于每一个 $j$
  • 遍历区间 $[0, j)$,找到其中使得 $b[j] \ge b[i]$ 并且 $dp[i]$ 最大的下标 $i$ —— 区间查询最大值
  • 令 $dp[j] = nums[j] + dp[i]$ —— 单点更新

我们需要让 『区间查询』+ 『单点更新』的时间复杂度严格小于 $O(n)$,这样才不会超时。哪一种现有的数据结构能做到这一点?—— 树状数组(Binary Index Tree)。

我本来以为 BIT 只适合区间查询求和(对于差分数组的优化),但这道题告诉了我们,区间查询的对象可以拓展到最小值和最大值。

BIT 中存放的是什么?<b[j], dp[j]> 的键值对,即 b[j] 是下标,dp[j] 是数组元素。我们要查询所有下标 $\le b[j]$ 的元素的最大值,即前缀区间的最大值。

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import java.util.Arrays;
import java.util.stream.IntStream;

public class Solution {
static class BIT {
long[] arr;
int offset;

BIT(int min, int max) {
// 将区间 [min, max] 移动到从 0 开始
offset = -min;
// BIT 的下标 0 必须留空,因此上述区间从 1 开始
offset++;
arr = new long[max - min + 2];
}

int lsb(int i) {
return i & -i;
}

void setMax(int i, long val) {
i += offset;
// 下标 i 会对之后所有区间产生影响,因此是加上 lsb
for (; i < arr.length; i += lsb(i)) {
arr[i] = Math.max(arr[i], val);
}
}

long getMax(int i) {
i += offset;
long ans = Long.MIN_VALUE;
// 前缀区间,因此是减去 lsb
for (; i > 0; i -= lsb(i)) {
ans = Math.max(ans, arr[i]);
}
return ans;
}

}

public long maxBalancedSubsequenceSum(int[] nums) {
int n = nums.length;
// b[i] = nums[i] - i
int[] b = IntStream.range(0, n).map(i -> nums[i] - i).toArray();
int max = Arrays.stream(b).max().getAsInt();
int min = Arrays.stream(b).min().getAsInt();
// 根据 b 中元素的范围定制 BIT#arr 的大小
BIT tree = new BIT(min, max);

long ans = Long.MIN_VALUE;
for (int j = 0; j < n; j++) {
// 区间查询最大值
long dp_i = Math.max(tree.getMax(b[j]), 0);
// 单点更新
long dp_j = dp_i + nums[j];
tree.setMax(b[j], dp_j);

ans = Math.max(ans, dp_j);
}
return ans;
}
}

提交该代码:

image-20231106113338051

注意题目的数据范围:$n \le 10^5$ ,$nums[i] \in[-10^9,10^9]$

因此 b[i] 的取值范围约等于 32-bit 整型的取值范围,即 $2^{32}$,这显然会在创建 BIT#arr 时导致 OOM。如何优化空间呢?

离散化

参考 OI-wiki 的描述:

离散化是一种数据处理的技巧,本质上可以看成是一种哈希,其保证数据在哈希以后仍然保持原来的全/偏序关系。

通俗地讲就是当有些数据因为本身很大或者类型不支持,自身无法作为数组的下标来方便地处理,而影响最终结果的只有元素之间的相对大小关系时,我们可以将原来的数据按照排名来处理问题,即离散化。

结合本题来理解:

  • 我们要查询所有下标 $\le b[j]$ 的元素的最大值 —— $b[j]$ 有 $2^{32}$ 个可能取值。
  • 而数组 b 只有 n 个元素。
  • 假设 b[j] 在排序后是第 $k$ 小的元素,那么我们只需要查询前 $k - 1$ 个比 b[j] 小的元素中的最大值。
  • 因此,BIT#arr 的长度也是 n ,且 $n \ll 2^{32}$。
  • 假设我们知道 b[j] 在排序后的下标 k,那么『查询所有下标 $\le b[j]$ 的元素的最大值』就等价于 getMax(k)

从结果上来看,『离散化』应该叫做『归一化』更合适?🤔

离散化的过程也是典型的模板,背出来就行了:

  1. 创建原数组的副本。
  2. 将副本中的值从小到大排序。
  3. 将排序好的副本去重。
  4. 查找原数组的每一个元素在副本中的位置,位置即为排名,将其作为离散化后的值。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
/**
* @return ans[i] 表示 b[i] 在排序后是第几小的元素
*/
private int[] discretize(int[] b) {
int[] arr = Arrays.stream(b).sorted().distinct().toArray();
int n = b.length;
int[] ans = new int[n];

for (int i = 0; i < n; i++) {
// b[i] 一定存在于 arr 中
int order = Arrays.binarySearch(arr, b[i]);
// ans[i] >= 1 是因为 BIT 的下标 0 必须留空。实际上可以不用加 1
ans[i] = order + 1;
}
return ans;
}

如果不去重,C++ 使用 std::lower_bound 代替 std::binary_search。但是 Java 中没有这样的函数,因此一定要去重。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import java.util.Arrays;
import java.util.stream.IntStream;

public class Solution {
static class BIT {
long[] arr;

BIT(int n) {
arr = new long[n + 1];
}

int lsb(int i) {
return i & -i;
}

void setMax(int i, long val) {
// 下标 i 会对之后所有区间产生影响,因此是加上 lsb
for (; i < arr.length; i += lsb(i)) {
arr[i] = Math.max(arr[i], val);
}
}

long getMax(int i) {
long ans = Long.MIN_VALUE;
// 前缀区间,因此是减去 lsb
for (; i > 0; i -= lsb(i)) {
ans = Math.max(ans, arr[i]);
}
return ans;
}

}

/**
* @return ans[i] 表示 b[i] 在排序后是第几小的元素
*/
private int[] discretize(int[] b) {
int[] arr = Arrays.stream(b).sorted().distinct().toArray();
int n = b.length;
int[] ans = new int[n];

for (int i = 0; i < n; i++) {
// b[i] 一定存在于 arr 中
int order = Arrays.binarySearch(arr, b[i]);
// ans[i] >= 1 是因为 BIT 的下标 0 必须留空。实际上可以不用加 1
ans[i] = order + 1;
}
return ans;
}

public long maxBalancedSubsequenceSum(int[] nums) {
int n = nums.length;
// b[i] = nums[i] - i
int[] b = IntStream.range(0, n).map(i -> nums[i] - i).toArray();
// 离散化
int[] arr = discretize(b);
BIT tree = new BIT(n + 1);

long ans = Long.MIN_VALUE;
for (int j = 0; j < n; j++) {
// b[j] 是第 k 小的元素
int k = arr[j];
// 区间查询最大值
long dp_i = Math.max(tree.getMax(k), 0);
// 单点更新
long dp_j = dp_i + nums[j];
tree.setMax(k, dp_j);

ans = Math.max(ans, dp_j);
}
return ans;
}
}

参考文章