Skip to content

Commit

Permalink
Fix map_top_n to use keys to break ties on NULL values. prestodb#22778
Browse files Browse the repository at this point in the history
  • Loading branch information
kgpai committed Jun 10, 2024
1 parent f661a1c commit 48560a8
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ public static String mapKeysExists()
@SqlType("map(K, V)")
public static String mapTopN()
{
return "RETURN IF(n < 0, fail('n must be greater than or equal to 0'), map_from_entries(slice(array_sort(map_entries(map_filter(input, (k, v) -> v is not null)), (x, y) -> IF(x[2] < y[2], 1, IF(x[2] = y[2], IF(x[1] < y[1], 1, -1), -1))) || map_entries(map_filter(input, (k, v) -> v is null)), 1, n)))";
return "RETURN IF(n < 0, fail('n must be greater than or equal to 0'), map_from_entries(slice(array_sort(map_entries(map_filter(input, (k, v) -> v is not null)), (x, y) -> IF(x[2] < y[2], 1, IF(x[2] = y[2], IF(x[1] < y[1], 1, -1), -1))) "
+ "|| ARRAY_SORT(MAP_ENTRIES(MAP_FILTER(input, (k, v) -> v IS NULL)), (x, y) -> IF( x[1] < y[1], 1, -1)), 1, n)))";
}

@SqlInvokedScalarFunction(value = "map_top_n_keys", deterministic = true, calledOnNullInput = false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
import com.google.common.collect.ImmutableMap;
import org.testng.annotations.Test;

import java.util.HashMap;
import java.util.Map;

import static com.facebook.presto.common.type.DecimalType.createDecimalType;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.common.type.UnknownType.UNKNOWN;
Expand Down Expand Up @@ -95,6 +98,25 @@ public void testEmpty()
public void testNull()
{
assertFunction("MAP_TOP_N(NULL, 1)", mapType(UNKNOWN, UNKNOWN), null);

// If values are null , then use keys to break ties.
Map<Integer, Integer> expected = new HashMap<Integer, Integer>() {{
put(4, 4);
put(3, 1);
put(5, null);
}};

assertFunction("MAP_TOP_N(MAP(ARRAY[1, 2, 3, 4, 5], ARRAY[NULL, NULL, 1, 4, NULL]), 3)", mapType(INTEGER, INTEGER),
expected);

Map<String, Integer> expectedStringKey = new HashMap<String, Integer>() {{
put("ef", 6);
put("cd", 4);
put("ab", -1);
put("hi", null);
}};
assertFunction("MAP_TOP_N(MAP(ARRAY['ab', 'bc', 'ef', 'cd', 'hi'], ARRAY[-1, NULL, 6, 4, NULL]), 4)", mapType(createVarcharType(2), INTEGER),
expectedStringKey);
}

@Test
Expand Down

0 comments on commit 48560a8

Please sign in to comment.