Algorithm Problem 005: Maximum Balanced Subsequence Sum

Problem

Problem Source: 2926. Maximum Balanced Subsequence Sum

This article analyzes the last problem of LeetCode Weekly Contest 370, which I failed to solve during the contest. The solution is based on the article by endlesscheng.

Problem Description:

You are given a 0-indexed integer array nums.

A subsequence of nums having length k and consisting of indices $i_0 < i_1 < \dots < i_{k-1}$ is balanced if the following holds:

  • $nums[i_j] - nums[i_{j-1}] \ge i_j - i_{j-1}$, for every $j$ in the range $[1, k - 1]$.

A subsequence of nums having length 1 is considered balanced.

Return an integer denoting the maximum possible sum of elements in a balanced subsequence of nums.

A subsequence of an array is a new non-empty array that is formed from the original array by deleting some (possibly none) of the elements without disturbing the relative positions of the remaining elements.
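To make the definition concrete, the following sketch checks it directly for a given list of indices. The class and method names (BalancedCheck, isBalanced) are hypothetical helpers for illustration only, not part of the LeetCode signature.

```java
// Hypothetical helper: checks the "balanced" definition directly on a list of indices.
final class BalancedCheck {
    /**
     * Returns true if the subsequence of nums selected by the strictly increasing
     * indices idx is balanced, i.e. nums[idx[j]] - nums[idx[j-1]] >= idx[j] - idx[j-1]
     * for every j >= 1. A single index is always balanced.
     */
    static boolean isBalanced(int[] nums, int[] idx) {
        for (int j = 1; j < idx.length; j++) {
            // Indices must be strictly increasing to form a subsequence
            if (idx[j] <= idx[j - 1]) {
                return false;
            }
            // Use long: the difference of two values in [-1e9, 1e9] can overflow int
            long valueGap = (long) nums[idx[j]] - nums[idx[j - 1]];
            long indexGap = idx[j] - idx[j - 1];
            if (valueGap < indexGap) {
                return false;
            }
        }
        return true;
    }
}
```

For Example 1, isBalanced(new int[]{3, 3, 5, 6}, new int[]{0, 2, 3}) returns true, while the indices {0, 1} fail because 3 - 3 < 1 - 0.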

Example 1:

Input: nums = [3,3,5,6]
Output: 14
Explanation: In this example, the subsequence [3,5,6] consisting of indices 0, 2, and 3 can be selected.
nums[2] - nums[0] >= 2 - 0.
nums[3] - nums[2] >= 3 - 2.
Hence, it is a balanced subsequence, and its sum is the maximum among the balanced subsequences of nums.
The subsequence consisting of indices 1, 2, and 3 is also valid.
It can be shown that it is not possible to get a balanced subsequence with a sum greater than 14.

Example 2:

Input: nums = [5,-1,-3,8]
Output: 13
Explanation: In this example, the subsequence [5,8] consisting of indices 0 and 3 can be selected.
nums[3] - nums[0] >= 3 - 0.
Hence, it is a balanced subsequence, and its sum is the maximum among the balanced subsequences of nums.
It can be shown that it is not possible to get a balanced subsequence with a sum greater than 13.

Example 3:

Input: nums = [-2,-1]
Output: -1
Explanation: In this example, the subsequence [-1] can be selected.
It is a balanced subsequence, and its sum is the maximum among the balanced subsequences of nums.

Constraints:

  • $1 \le nums.length \le 10^5$
  • $-10^9 \le nums[i] \le 10^9$

Approach

DP

Let’s restate the problem: any two consecutive elements nums[i] and nums[j] of a balanced subsequence must satisfy $i < j$ and $nums[j] - nums[i] \ge j - i$.

By rearranging the inequality above, we get $nums[j] - j \ge nums[i] - i$.

Let $b[i] = nums[i] - i$. The condition becomes $b[j] \ge b[i]$, where each side depends only on the value and index of a single element, not on the other one. In other words, a subsequence is balanced exactly when its $b$ values are non-decreasing.

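Therefore, this problem resembles Longest Increasing Subsequence (LIS): we look for a subsequence whose $b$ values are non-decreasing, but we maximize the sum of its nums values instead of its length. We can solve it using DP: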

DP Definition: dp[j] represents the maximum sum of elements in a balanced subsequence ending at nums[j].

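Recurrence relation: $dp[j] = nums[j] + \max\left(0, \max\{dp[i] \mid (i < j) \land (b[i] \le b[j])\}\right)$. The outer $\max$ with $0$ covers the cases where no valid $i$ exists or every candidate $dp[i]$ is negative; then it is better to start a new subsequence at $j$.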
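As a worked check of the recurrence, here is the table for Example 1 (nums = [3,3,5,6]), computed by hand:

```latex
\begin{array}{c|cccc}
j & 0 & 1 & 2 & 3 \\ \hline
nums[j] & 3 & 3 & 5 & 6 \\
b[j] = nums[j] - j & 3 & 2 & 3 & 3 \\
dp[j] & 3 & 3 & 8 & 14
\end{array}
```

$dp[1] = 3$ because $b[0] = 3 > b[1] = 2$, so index 1 has no valid predecessor. $dp[2] = 5 + \max(dp[0], dp[1]) = 8$ since $b[0] = 3 \le 3$ and $b[1] = 2 \le 3$, and $dp[3] = 6 + \max(3, 3, 8) = 14$. The maximum over all $dp[j]$ is 14, matching the expected output.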

Code:

```java
import java.util.Arrays;
import java.util.stream.IntStream;

class Solution {
    public long maxBalancedSubsequenceSum(int[] nums) {
        int n = nums.length;
        // b[i] = nums[i] - i
        int[] b = IntStream.range(0, n).map(i -> nums[i] - i).toArray();
        // dp[i]: maximum sum of elements in a balanced subsequence ending at nums[i]
        long[] dp = new long[n];
        for (int j = 0; j < n; j++) {
            dp[j] = nums[j];
            // Best predecessor sum; 0 means "start a new subsequence at j"
            long max = 0;
            for (int i = 0; i < j; i++) {
                if (b[j] >= b[i])
                    max = Math.max(max, dp[i]);
            }
            dp[j] += max;
        }
        return Arrays.stream(dp).max().getAsLong();
    }
}
```

Clearly, the nested loops make this code $O(n^2)$. With $n$ up to $10^5$, that is about $5 \times 10^9$ inner-loop iterations, which will time out. How can we improve it?

Binary Indexed Tree (BIT)

Based on the recurrence relation above, the code logic is:

  • For each $j$:
    • Among the indices $i \in [0, j)$, find one that satisfies $b[i] \le b[j]$ and maximizes $dp[i]$. This is a range query for the maximum value.
    • Set $dp[j] = nums[j] + \max(0, dp[i])$ and store it for later queries. This is a point update.

Both the range query and the point update must run in less than $O(n)$ time; otherwise we are back to $O(n^2)$ overall. A Binary Indexed Tree (BIT) supports both in time logarithmic in the size of the tree.

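Initially, I thought a BIT was only suitable for range sum queries (e.g., together with difference arrays), but this problem shows that it can also maintain maximum (or minimum) values. There is a catch: this works only for prefix queries, and only when updates never decrease a stored value. Max has no inverse, so we can neither subtract two prefixes to get an arbitrary range nor undo an old maximum. Both conditions hold here, because we only query prefixes and setMax only ever raises values.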

The BIT stores key-value pairs <b[j], dp[j]>: b[j] serves as the index and dp[j] as the element value. Querying the maximum over all elements whose index is $\le b[j]$ is therefore a prefix-maximum query.
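To see the prefix-maximum idea in isolation, here is a minimal sketch of a 1-based Fenwick tree for maxima. The class name PrefixMaxFenwick is hypothetical; it assumes each update only raises the value at a position, and it uses Long.MIN_VALUE as the "nothing stored yet" sentinel.

```java
import java.util.Arrays;

// Minimal prefix-maximum Fenwick tree (1-based). Assumes updates only ever
// raise the value at a position; decreasing a value is not supported.
final class PrefixMaxFenwick {
    private final long[] tree;

    PrefixMaxFenwick(int size) {
        tree = new long[size + 1];          // index 0 is unused
        Arrays.fill(tree, Long.MIN_VALUE);  // sentinel: no value stored yet
    }

    /** Raise position i (1 <= i <= size) to at least val. */
    void update(int i, long val) {
        // Climb to every node whose range covers i
        for (; i < tree.length; i += i & -i) {
            tree[i] = Math.max(tree[i], val);
        }
    }

    /** Maximum over positions 1..i, or Long.MIN_VALUE if none was set. */
    long query(int i) {
        long best = Long.MIN_VALUE;
        // Walk down through disjoint ranges that together cover [1, i]
        for (; i > 0; i -= i & -i) {
            best = Math.max(best, tree[i]);
        }
        return best;
    }
}
```

For example, after update(3, 5) and update(7, 9) on a tree of size 8, query(6) returns 5 and query(7) returns 9.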

Code:

```java
import java.util.Arrays;
import java.util.stream.IntStream;

public class Solution {
    static class BIT {
        long[] arr;
        int offset;

        BIT(int min, int max) {
            // Move the interval [min, max] to start from 0
            offset = -min;
            // Leave index 0 of BIT empty, so the interval starts from 1
            offset++;
            // One slot per possible value of b: up to ~2e9 longs in this problem
            arr = new long[max - min + 2];
        }

        int lsb(int i) {
            return i & -i;
        }

        void setMax(int i, long val) {
            i += offset;
            // Index i affects all subsequent intervals, so we add lsb
            for (; i < arr.length; i += lsb(i)) {
                arr[i] = Math.max(arr[i], val);
            }
        }

        long getMax(int i) {
            i += offset;
            long ans = Long.MIN_VALUE;
            // Prefix interval, so we subtract lsb
            for (; i > 0; i -= lsb(i)) {
                ans = Math.max(ans, arr[i]);
            }
            return ans;
        }

    }

    public long maxBalancedSubsequenceSum(int[] nums) {
        int n = nums.length;
        // b[i] = nums[i] - i
        int[] b = IntStream.range(0, n).map(i -> nums[i] - i).toArray();
        int max = Arrays.stream(b).max().getAsInt();
        int min = Arrays.stream(b).min().getAsInt();
        // Customize the size of BIT#arr based on the range of elements in b
        BIT tree = new BIT(min, max);

        long ans = Long.MIN_VALUE;
        for (int j = 0; j < n; j++) {
            // Prefix-max query over keys <= b[j]; clamp at 0 to allow starting at j
            long dp_i = Math.max(tree.getMax(b[j]), 0);
            // Point update
            long dp_j = dp_i + nums[j];
            tree.setMax(b[j], dp_j);

            ans = Math.max(ans, dp_j);
        }
        return ans;
    }
}
```

Submitting this code fails with an Out Of Memory (OOM) error.

Note the data range of the problem: $n \le 10^5$, $nums[i] \in [-10^9, 10^9]$.

Therefore, b[i] = nums[i] - i lies roughly in $[-10^9 - 10^5, 10^9]$, a range of about $2 \times 10^9$ values (close to $2^{31}$). Allocating one slot per possible value in BIT#arr is what causes the Out Of Memory (OOM) error. How can we reduce the space?
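A back-of-the-envelope check of the worst-case allocation made by new BIT(min, max), assuming both extremes of nums (and the largest index) actually occur in the input:

```latex
\begin{aligned}
b[i] &= nums[i] - i \in \left[-10^9 - (10^5 - 1),\ 10^9\right] \\
\max b - \min b + 2 &\le 2 \times 10^9 + 10^5 + 1 \approx 2 \times 10^9 \text{ slots} \\
2 \times 10^9 \times 8 \text{ bytes (one long each)} &\approx 1.6 \times 10^{10} \text{ bytes} \approx 16 \text{ GB}
\end{aligned}
```

That is far beyond any judge's memory limit, while the array b itself has at most $10^5$ entries.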

Discretization

Quoting OI-wiki:

Discretization is a data processing technique, essentially a form of hashing, that ensures that data remains in its original full/partial order after hashing.

In simple terms, when values are too large, or of a type that cannot conveniently be used as an array index, and only the relative order of the elements affects the final result, we can replace each value with its rank. This is discretization.

Let’s apply this concept to the problem:

  • We want to query the maximum value over all indices $\le b[j]$, but b[j] can take about $2 \times 10^9$ different values.
  • However, the array b has only n elements, so at most n of those values actually occur.
  • Suppose b[j] has rank $k$ among the distinct values of b sorted in ascending order (rank 1 is the smallest). Then the indices $\le b[j]$ that matter are exactly the values with ranks $1, 2, \dots, k$, including b[j] itself and any elements equal to it.
  • Therefore, BIT#arr only needs about n slots, and $n \ll 2 \times 10^9$.
  • Once we know the rank k of b[j], the operation ‘query the maximum value of all indices less than or equal to b[j]’ is equivalent to getMax(k).

Judging by its effect, shouldn’t ‘Discretization’ rather be called ‘Normalization’? 🤔

The process of discretization is a typical template. Just memorize it:

  1. Create a copy of the original array.
  2. Sort the values in the copy in ascending order.
  3. Remove duplicates from the sorted copy.
  4. For each element of the original array, find its position in the sorted, deduplicated copy (e.g., by binary search). That position is the element’s ranking and is used as its discretized value.
```java
/**
 * @return ans[i] is the 1-based ranking of b[i] among the sorted distinct values of b
 */
private int[] discretize(int[] b) {
    int[] arr = Arrays.stream(b).sorted().distinct().toArray();
    int n = b.length;
    int[] ans = new int[n];

    for (int i = 0; i < n; i++) {
        // b[i] must exist in arr, so binarySearch returns its index
        int order = Arrays.binarySearch(arr, b[i]);
        // +1 makes rankings 1-based: index 0 of the BIT must stay unused,
        // since lsb(0) == 0 would make setMax loop forever
        ans[i] = order + 1;
    }
    return ans;
}
```
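To see the template in action, the sketch below wraps the same discretize logic in a hypothetical standalone class (DiscretizeDemo, with an illustrative helper ranksFor) and applies it to Example 2. For nums = [5,-1,-3,8], b = [5,-2,-5,5], the sorted distinct copy is [-5,-2,5], and the rankings are [3,2,1,3].

```java
import java.util.Arrays;
import java.util.stream.IntStream;

// Hypothetical standalone copy of the discretize() step, for experimentation.
final class DiscretizeDemo {
    static int[] discretize(int[] b) {
        // Sorted copy with duplicates removed
        int[] sorted = Arrays.stream(b).sorted().distinct().toArray();
        int[] rank = new int[b.length];
        for (int i = 0; i < b.length; i++) {
            // b[i] occurs exactly once in sorted, so binarySearch finds it
            rank[i] = Arrays.binarySearch(sorted, b[i]) + 1; // 1-based rank
        }
        return rank;
    }

    /** Illustrative helper: builds b[i] = nums[i] - i and discretizes it. */
    static int[] ranksFor(int[] nums) {
        int[] b = IntStream.range(0, nums.length).map(i -> nums[i] - i).toArray();
        return discretize(b);
    }
}
```

With these rankings, processing j = 3 calls getMax(3), which covers ranks 1..3 (all b values $\le 5$), so it finds dp[0] = 5 and yields 5 + 8 = 13, the expected answer.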

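Why remove duplicates? When an array contains several equal values, Java’s Arrays.binarySearch makes no guarantee about which of them it finds, so without deduplication equal b values could receive different rankings, and a prefix query for the smaller ranking would miss the other. In C++ we could skip deduplication and use std::lower_bound (not std::binary_search, which only returns a bool), because lower_bound always returns the first position among equal values. Java’s Arrays class has no such function, so we remove duplicates.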
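If you prefer to skip deduplication in Java, a lower_bound is easy to write by hand. The sketch below is a hypothetical helper (LowerBound, discretizeWithDuplicates), not a standard library API; equal values always map to the same first position, so they share a ranking, although rankings may skip numbers.

```java
import java.util.Arrays;

// Hypothetical helper: a hand-written lower_bound, since java.util.Arrays has none.
final class LowerBound {
    /** First index in sorted array a whose value is >= key, or a.length if none. */
    static int lowerBound(int[] a, int key) {
        int lo = 0, hi = a.length; // search window [lo, hi)
        while (lo < hi) {
            int mid = (lo + hi) >>> 1;
            if (a[mid] < key) {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        return lo;
    }

    /** Rankings without removing duplicates: equal values share one ranking. */
    static int[] discretizeWithDuplicates(int[] b) {
        int[] sorted = b.clone();
        Arrays.sort(sorted);
        int[] rank = new int[b.length];
        for (int i = 0; i < b.length; i++) {
            rank[i] = lowerBound(sorted, b[i]) + 1; // 1-based, may have gaps
        }
        return rank;
    }
}
```

For b = [1,1,7] this gives rankings [1,1,3]: the gap at 2 is harmless, because prefix queries still cover exactly the values $\le b[j]$.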

Code

```java
import java.util.Arrays;
import java.util.stream.IntStream;

public class Solution {
    static class BIT {
        long[] arr;

        BIT(int n) {
            arr = new long[n + 1];
        }

        int lsb(int i) {
            return i & -i;
        }

        void setMax(int i, long val) {
            // Index i affects all subsequent intervals, so we add lsb
            for (; i < arr.length; i += lsb(i)) {
                arr[i] = Math.max(arr[i], val);
            }
        }

        long getMax(int i) {
            long ans = Long.MIN_VALUE;
            // Prefix interval, so we subtract lsb
            for (; i > 0; i -= lsb(i)) {
                ans = Math.max(ans, arr[i]);
            }
            return ans;
        }

    }

    /**
     * @return ans[i] is the 1-based ranking of b[i] among the sorted distinct values of b
     */
    private int[] discretize(int[] b) {
        int[] arr = Arrays.stream(b).sorted().distinct().toArray();
        int n = b.length;
        int[] ans = new int[n];

        for (int i = 0; i < n; i++) {
            // b[i] must exist in arr, so binarySearch returns its index
            int order = Arrays.binarySearch(arr, b[i]);
            // +1 makes rankings 1-based: index 0 of the BIT must stay unused,
            // since lsb(0) == 0 would make setMax loop forever
            ans[i] = order + 1;
        }
        return ans;
    }

    public long maxBalancedSubsequenceSum(int[] nums) {
        int n = nums.length;
        // b[i] = nums[i] - i
        int[] b = IntStream.range(0, n).map(i -> nums[i] - i).toArray();
        // Discretization: arr[j] is the ranking of b[j], in [1, n]
        int[] arr = discretize(b);
        // The BIT now needs only O(n) space
        BIT tree = new BIT(n + 1);

        long ans = Long.MIN_VALUE;
        for (int j = 0; j < n; j++) {
            // b[j] is the k-th smallest distinct value
            int k = arr[j];
            // Prefix-max query over rankings 1..k; clamp at 0 to allow starting at j
            long dp_i = Math.max(tree.getMax(k), 0);
            // Point update
            long dp_j = dp_i + nums[j];
            tree.setMax(k, dp_j);

            ans = Math.max(ans, dp_j);
        }
        return ans;
    }
}
```
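To gain some confidence before submitting, one option is to cross-check the fast solution against the $O(n^2)$ DP on small random inputs. The sketch below is a hypothetical test harness (class CrossCheck with methods fast, slow and agreeOnRandomInputs), not part of the original solution: fast is a compact re-implementation of the discretization + BIT approach above, and slow is the quadratic DP.

```java
import java.util.Arrays;
import java.util.Random;

// Hypothetical test harness: discretization + BIT versus the O(n^2) DP.
final class CrossCheck {
    static long fast(int[] nums) {
        int n = nums.length;
        int[] b = new int[n];
        for (int i = 0; i < n; i++) b[i] = nums[i] - i;
        int[] sorted = Arrays.stream(b).sorted().distinct().toArray();
        long[] tree = new long[sorted.length + 1]; // prefix-max BIT, index 0 unused
        long ans = Long.MIN_VALUE;
        for (int j = 0; j < n; j++) {
            int k = Arrays.binarySearch(sorted, b[j]) + 1; // 1-based ranking
            long best = 0; // max(0, best dp among rankings 1..k)
            for (int i = k; i > 0; i -= i & -i) best = Math.max(best, tree[i]);
            long dp = best + nums[j];
            for (int i = k; i < tree.length; i += i & -i) tree[i] = Math.max(tree[i], dp);
            ans = Math.max(ans, dp);
        }
        return ans;
    }

    static long slow(int[] nums) {
        int n = nums.length;
        long[] dp = new long[n];
        long ans = Long.MIN_VALUE;
        for (int j = 0; j < n; j++) {
            long best = 0; // 0 means "start a new subsequence at j"
            for (int i = 0; i < j; i++) {
                if (nums[j] - j >= nums[i] - i) best = Math.max(best, dp[i]);
            }
            dp[j] = nums[j] + best;
            ans = Math.max(ans, dp[j]);
        }
        return ans;
    }

    /** True if fast() and slow() agree on `trials` random arrays of length 1..8. */
    static boolean agreeOnRandomInputs(long seed, int trials) {
        Random rnd = new Random(seed);
        for (int t = 0; t < trials; t++) {
            int[] nums = new int[1 + rnd.nextInt(8)];
            for (int i = 0; i < nums.length; i++) nums[i] = rnd.nextInt(21) - 10;
            if (fast(nums) != slow(nums)) return false;
        }
        return true;
    }
}
```

Small values in [-10, 10] are chosen deliberately, so that duplicate b values and negative sums, the cases most likely to expose a ranking or clamping bug, show up often.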

References