Time for some Funky SQL: Prefix Sum Calculation
This Stack Overflow question has yet again nerd-sniped me:
[finding the] maximum element in the array that would result from performing all M operations
Here’s the question by John that was looking for a Java solution:
With an array of N elements which are initialized to 0. We are given a sequence of M operations of the sort (p; q; r). The operation (p; q; r) signifies that the integer r should be added to all array elements A[p]; A[p + 1]; ...; A[q]. You are to output the maximum element in the array that would result from performing all M operations. There is a naive solution that simply performs all operations and then returns the maximum value, that takes O(MN) time. We are looking for a more efficient algorithm.
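Interesting. Indeed, a naive solution would just perform all the operations as requested. A less naive solution would transform each operation into two signals of the form (x; y): a signal (p; r) that switches the addition of r on at position p, and a signal (q + 1; -r) that switches it off again right after position q. The value at any position is then the sum of all signals at or before that position, which is a prefix sum. In other words, we could implement the solution I had presented trivially as such: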
    import java.util.IntSummaryStatistics;
    import java.util.stream.IntStream;
    import java.util.stream.Stream;

    // This is just a utility class to model the ops
    class Operation {
        final int p;
        final int q;
        final int r;

        Operation(int p, int q, int r) {
            this.p = p;
            this.q = q;
            this.r = r;
        }
    }

    // The statements below belong inside a method, e.g. main()

    // These are some example ops
    Operation[] ops = {
        new Operation(4, 12, 2),
        new Operation(2, 8, 3),
        new Operation(6, 7, 1),
        new Operation(3, 7, 2)
    };

    // Here, we're calculating the min and max
    // values for the combined values of p and q
    IntSummaryStatistics stats = Stream
        .of(ops)
        .flatMapToInt(op -> IntStream.of(op.p, op.q))
        .summaryStatistics();

    // Create an array for all the required elements using
    // the min value as "offset". The + 1 makes sure that the
    // largest q gets a slot of its own.
    int[] array = new int[stats.getMax() - stats.getMin() + 1];

    // Put +r and -r "signals" into the array for each op
    for (Operation op : ops) {
        int lo = op.p - stats.getMin();
        int hi = op.q + 1 - stats.getMin();

        array[lo] = array[lo] + op.r;

        // A -r signal beyond the last position cannot
        // influence any value in the array, so we skip it
        if (hi < array.length)
            array[hi] = array[hi] - op.r;
    }

    // Now, calculate the prefix sum sequentially in a
    // trivial loop
    int maxIndex = Integer.MIN_VALUE;
    int maxR = Integer.MIN_VALUE;
    int r = 0;

    for (int i = 0; i < array.length; i++) {
        r = r + array[i];
        System.out.println((i + stats.getMin()) + ":" + r);

        if (r > maxR) {
            maxIndex = i + stats.getMin();
            maxR = r;
        }
    }

    System.out.println("---");
    System.out.println(maxIndex + ":" + maxR);

The above program would print out:

    2:3
    3:5
    4:7
    5:7
    6:8
    7:8
    8:5
    9:2
    10:2
    11:2
    12:2
    ---
    6:8
So, the maximum value is generated at position 6, and the value is 8.
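As a sanity check, the naive O(MN) approach from the question can be sketched in a few lines and compared against the result above. This is only an illustrative baseline: the class name NaiveRangeMax, the method maxAfterOps and the representation of each operation as an int[] {p, q, r} are assumptions made for this sketch, and positions are assumed to be 1-based.

```java
public class NaiveRangeMax {

    // Applies every (p, q, r) operation element by element: O(M * N) time.
    // Each row of ops is {p, q, r}; positions are assumed to be 1-based
    // and inclusive, and n is the assumed array size.
    static int maxAfterOps(int n, int[][] ops) {
        int[] a = new int[n + 1]; // index 0 stays unused

        for (int[] op : ops)
            for (int i = op[0]; i <= op[1]; i++)
                a[i] += op[2];

        int max = Integer.MIN_VALUE;
        for (int i = 1; i <= n; i++)
            max = Math.max(max, a[i]);

        return max;
    }

    public static void main(String[] args) {
        int[][] ops = { {4, 12, 2}, {2, 8, 3}, {6, 7, 1}, {3, 7, 2} };
        System.out.println(maxAfterOps(12, ops)); // prints 8
    }
}
```

Both approaches agree on the maximum of 8 for the example operations. The signal-based version only touches two array slots per operation, whereas this baseline touches every slot between p and q.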
Faster calculation in Java 8
This can be calculated faster using Java 8’s new Arrays.parallelPrefix() operation. Instead of the loop in the end, just write:
    Arrays.parallelPrefix(array, Integer::sum);
    System.out.println(
        Arrays.stream(array).parallel().max());
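Which is awesome, as it can run faster than the sequential O(M+N) solution, at least on multi-core hardware and for large arrays. Note that max() returns an OptionalInt here, so this variant prints only the maximum value, not the position at which it occurs. Read up about prefix sums here.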
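If the position of the maximum is needed as well, one option is to keep the parallel prefix sum and follow it with a simple scan. The following is a minimal sketch under a few assumptions: the class name ParallelPrefixMax and the method maxOfPrefixSums are made up for illustration, the signal array is assumed to be non-empty, and the sample array holds the net +r / -r signals for positions 2 to 12 of the example operations.

```java
import java.util.Arrays;

public class ParallelPrefixMax {

    // signals[i] holds the net +r / -r signal at position (offset + i).
    // Returns {position, value} of the first maximum of the prefix sums.
    // Assumes signals contains at least one element.
    static int[] maxOfPrefixSums(int[] signals, int offset) {
        int[] sums = signals.clone();              // keep the input intact
        Arrays.parallelPrefix(sums, Integer::sum); // in-place cumulative sum

        int best = 0;
        for (int i = 1; i < sums.length; i++)
            if (sums[i] > sums[best])
                best = i;

        return new int[] { offset + best, sums[best] };
    }

    public static void main(String[] args) {
        // Net signals for positions 2..12 of the example operations
        int[] signals = { 3, 2, 2, 0, 1, 0, -3, -3, 0, 0, 0 };
        int[] result = maxOfPrefixSums(signals, 2);
        System.out.println(result[0] + ":" + result[1]); // prints 6:8
    }
}
```

The final scan is sequential in this sketch. Whether the parallel prefix step pays off depends on the array size and the available cores, so it is worth measuring before relying on it.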
Now show me the promised SQL code
In SQL, the naive sequential and linear complexity solution can easily be re-implemented, and I’m showing a solution for PostgreSQL.
How can we do it? We’re using a couple of features here. First off, we’re using common table expressions (also known as the WITH clause). We’re using these to declare table variables. The first variable is the op table, which contains our operation instructions, like in Java:
    WITH
      op (p, q, r) AS (
        VALUES
          (4, 12, 2),
          (2, 8, 3),
          (6, 7, 1),
          (3, 7, 2)
      ),
      ...
This is trivial. We’re essentially just generating a couple of example values.
The second table variable is the signal table, where we use the previously described optimisation of putting a +r signal at all p positions, and a -r signal at all q + 1 positions:
    WITH
      ...,
      signal(x, r) AS (
        SELECT p, r
        FROM op
        UNION ALL
        SELECT q + 1, -r
        FROM op
      )
    ...
When you run:
    SELECT * FROM signal ORDER BY x
you would simply get:
     x  |  r
    ----+----
      2 |  3
      3 |  2
      4 |  2
      6 |  1
      8 | -2
      8 | -1
      9 | -3
     13 | -2
All we need to do now is calculate a running total (which is essentially the same as a prefix sum) as follows:
    SELECT x, SUM(r) OVER (ORDER BY x)
    FROM signal
    ORDER BY x
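     x  | sum
    ----+-----
      2 |   3
      3 |   5
      4 |   7
      6 |   8
      8 |   5
      8 |   5
      9 |   2
     13 |   0

Both rows for x = 8 show the same total of 5, because with ORDER BY x the default window frame includes all rows that share the current x value. Now just find the max value of this running total, and we’re all set. We’ll take the shortcut by using ORDER BY and LIMIT: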
    SELECT x, SUM(r) OVER (ORDER BY x) AS s
    FROM signal
    ORDER BY s DESC
    LIMIT 1
And we’re back with:
     x | s
    ---+---
     6 | 8
Perfect! Here’s the full query:
    WITH
      op (p, q, r) AS (
        VALUES
          (4, 12, 2),
          (2, 8, 3),
          (6, 7, 1),
          (3, 7, 2)
      ),
      signal(x, r) AS (
        SELECT p, r
        FROM op
        UNION ALL
        SELECT q + 1, -r
        FROM op
      )
    SELECT x, SUM(r) OVER (ORDER BY x) AS s
    FROM signal
    ORDER BY s DESC
    LIMIT 1
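For comparison, the same sparse idea (only looking at the positions that carry a signal, in sorted order) can be sketched in Java with a sorted map. This is a rough equivalent rather than a translation of the query: the class name SparseSignalMax, the method maxRunningTotal and the int[] {p, q, r} representation of an operation are assumptions made for this sketch. Merging signals that share the same x plays the role of the window function treating rows with equal x as peers.

```java
import java.util.Map;
import java.util.TreeMap;

public class SparseSignalMax {

    // Each row of ops is {p, q, r}. Returns {x, s} for the first maximum
    // of the running total, or null when there are no operations.
    static int[] maxRunningTotal(int[][] ops) {
        // Sorted map from position x to the net signal at x
        TreeMap<Integer, Integer> signal = new TreeMap<>();
        for (int[] op : ops) {
            signal.merge(op[0], op[2], Integer::sum);      // +r at p
            signal.merge(op[1] + 1, -op[2], Integer::sum); // -r at q + 1
        }

        // Running total over the positions in ascending order
        int[] best = null;
        int running = 0;
        for (Map.Entry<Integer, Integer> e : signal.entrySet()) {
            running += e.getValue();
            if (best == null || running > best[1])
                best = new int[] { e.getKey(), running };
        }
        return best;
    }

    public static void main(String[] args) {
        int[][] ops = { {4, 12, 2}, {2, 8, 3}, {6, 7, 1}, {3, 7, 2} };
        int[] result = maxRunningTotal(ops);
        System.out.println(result[0] + ":" + result[1]); // prints 6:8
    }
}
```

Unlike the array-based Java version, this sketch needs memory proportional to the number of operations rather than to the range of positions, at the cost of keeping the map sorted.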
Can you beat the conciseness of this SQL solution? I bet you can’t. Challengers shall write alternatives in the comment section.
Reference: Time for some Funky SQL: Prefix Sum Calculation from our JCG partner Lukas Eder at the JAVA, SQL, AND JOOQ blog.