Skip to content

Commit

Permalink
Fix map_top_n to use keys to break ties on NULL values. prestodb#22778
Browse files Browse the repository at this point in the history
  • Loading branch information
kgpai authored and root committed Aug 8, 2024
1 parent a230dae commit 142928e
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 3 deletions.
4 changes: 2 additions & 2 deletions presto-docs/src/main/sphinx/functions/map.rst
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ Map Functions

.. function:: map_top_n(x(K,V), n) -> map(K, V)

Truncates map items. Keeps only the top N elements by value.
``n`` must be a non-negative integer.::
Truncates map items. Keeps only the top ``n`` elements by value. Keys are used to break ties with the max key being chosen. Both keys and values should be orderable.
``n`` must be a non-negative integer. ::

SELECT map_top_n(map(ARRAY['a', 'b', 'c'], ARRAY[2, 3, 1]), 2) --- {'b' -> 3, 'a' -> 2}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ public static String mapKeysExists()
@SqlType("map(K, V)")
public static String mapTopN()
{
return "RETURN IF(n < 0, fail('n must be greater than or equal to 0'), map_from_entries(slice(array_sort(map_entries(map_filter(input, (k, v) -> v is not null)), (x, y) -> IF(x[2] < y[2], 1, IF(x[2] = y[2], IF(x[1] < y[1], 1, -1), -1))) || map_entries(map_filter(input, (k, v) -> v is null)), 1, n)))";
return "RETURN IF(n < 0, fail('n must be greater than or equal to 0'), map_from_entries(slice(array_sort(map_entries(map_filter(input, (k, v) -> v is not null)), (x, y) -> IF(x[2] < y[2], 1, IF(x[2] = y[2], IF(x[1] < y[1], 1, -1), -1))) "
+ "|| ARRAY_SORT(MAP_ENTRIES(MAP_FILTER(input, (k, v) -> v IS NULL)), (x, y) -> IF( x[1] < y[1], 1, -1)), 1, n)))";
}

@SqlInvokedScalarFunction(value = "map_top_n_keys", deterministic = true, calledOnNullInput = false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
import com.google.common.collect.ImmutableMap;
import org.testng.annotations.Test;

import java.util.HashMap;
import java.util.Map;

import static com.facebook.presto.common.type.DecimalType.createDecimalType;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.common.type.UnknownType.UNKNOWN;
Expand Down Expand Up @@ -95,6 +98,25 @@ public void testEmpty()
public void testNull()
{
assertFunction("MAP_TOP_N(NULL, 1)", mapType(UNKNOWN, UNKNOWN), null);

// If values are null, then use keys to break ties.
Map<Integer, Integer> expected = new HashMap<Integer, Integer>() {{
put(4, 4);
put(3, 1);
put(5, null);
}};

assertFunction("MAP_TOP_N(MAP(ARRAY[1, 2, 3, 4, 5], ARRAY[NULL, NULL, 1, 4, NULL]), 3)", mapType(INTEGER, INTEGER),
expected);

Map<String, Integer> expectedStringKey = new HashMap<String, Integer>() {{
put("ef", 6);
put("cd", 4);
put("ab", -1);
put("hi", null);
}};
assertFunction("MAP_TOP_N(MAP(ARRAY['ab', 'bc', 'ef', 'cd', 'hi'], ARRAY[-1, NULL, 6, 4, NULL]), 4)", mapType(createVarcharType(2), INTEGER),
expectedStringKey);
}

@Test
Expand Down

0 comments on commit 142928e

Please sign in to comment.