diff --git a/presto-docs/src/main/sphinx/functions/map.rst b/presto-docs/src/main/sphinx/functions/map.rst index dc155e54b1ed1..ae6c24987f09f 100644 --- a/presto-docs/src/main/sphinx/functions/map.rst +++ b/presto-docs/src/main/sphinx/functions/map.rst @@ -136,8 +136,8 @@ Map Functions .. function:: map_top_n(x(K,V), n) -> map(K, V) - Truncates map items. Keeps only the top N elements by value. - ``n`` must be a non-negative integer.:: + Truncates map items. Keeps only the top ``n`` elements by value. Keys are used to break ties with the max key being chosen. Both keys and values should be orderable. + ``n`` must be a non-negative integer. :: SELECT map_top_n(map(ARRAY['a', 'b', 'c'], ARRAY[2, 3, 1]), 2) --- {'b' -> 3, 'a' -> 2} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/MapSqlFunctions.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/MapSqlFunctions.java index dba6659963d92..0cf9558a22989 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/MapSqlFunctions.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/sql/MapSqlFunctions.java @@ -54,7 +54,8 @@ public static String mapKeysExists() @SqlType("map(K, V)") public static String mapTopN() { - return "RETURN IF(n < 0, fail('n must be greater than or equal to 0'), map_from_entries(slice(array_sort(map_entries(map_filter(input, (k, v) -> v is not null)), (x, y) -> IF(x[2] < y[2], 1, IF(x[2] = y[2], IF(x[1] < y[1], 1, -1), -1))) || map_entries(map_filter(input, (k, v) -> v is null)), 1, n)))"; + return "RETURN IF(n < 0, fail('n must be greater than or equal to 0'), map_from_entries(slice(array_sort(map_entries(map_filter(input, (k, v) -> v is not null)), (x, y) -> IF(x[2] < y[2], 1, IF(x[2] = y[2], IF(x[1] < y[1], 1, -1), -1))) " + + "|| ARRAY_SORT(MAP_ENTRIES(MAP_FILTER(input, (k, v) -> v IS NULL)), (x, y) -> IF( x[1] < y[1], 1, -1)), 1, n)))"; } @SqlInvokedScalarFunction(value = "map_top_n_keys", deterministic = true, calledOnNullInput = false) diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestMapTopNFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestMapTopNFunction.java index 99ba8020b983d..b08c5bd623a4a 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestMapTopNFunction.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/sql/TestMapTopNFunction.java @@ -20,6 +20,9 @@ import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; +import java.util.HashMap; +import java.util.Map; + import static com.facebook.presto.common.type.DecimalType.createDecimalType; import static com.facebook.presto.common.type.IntegerType.INTEGER; import static com.facebook.presto.common.type.UnknownType.UNKNOWN; @@ -95,6 +98,25 @@ public void testEmpty() public void testNull() { assertFunction("MAP_TOP_N(NULL, 1)", mapType(UNKNOWN, UNKNOWN), null); + + // If values are null, then use keys to break ties. + Map expected = new HashMap() {{ + put(4, 4); + put(3, 1); + put(5, null); + }}; + + assertFunction("MAP_TOP_N(MAP(ARRAY[1, 2, 3, 4, 5], ARRAY[NULL, NULL, 1, 4, NULL]), 3)", mapType(INTEGER, INTEGER), + expected); + + Map expectedStringKey = new HashMap() {{ + put("ef", 6); + put("cd", 4); + put("ab", -1); + put("hi", null); + }}; + assertFunction("MAP_TOP_N(MAP(ARRAY['ab', 'bc', 'ef', 'cd', 'hi'], ARRAY[-1, NULL, 6, 4, NULL]), 4)", mapType(createVarcharType(2), INTEGER), + expectedStringKey); } @Test