Skip to content

Commit

Permalink
Add map_top_n Presto function
Browse files Browse the repository at this point in the history
Differential Revision: D54877549
  • Loading branch information
mbasmanova authored and facebook-github-bot committed Mar 13, 2024
1 parent 17fc58b commit 78859a1
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 0 deletions.
9 changes: 9 additions & 0 deletions velox/docs/functions/presto/map.rst
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ Map Functions

SELECT map_from_entries(ARRAY[(1, 'x'), (2, 'y')]); -- {1 -> 'x', 2 -> 'y'}

.. function:: map_top_n(map(K,V), n) -> map(K, V)

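Truncates map items. Keeps only the top N elements by value.
Entries with NULL values are ranked below entries with non-NULL values.
Returns the input map unchanged if it has N or fewer entries and an empty map if ``n`` is 0.
``n`` must be a non-negative BIGINT value; a negative ``n`` results in an error.::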

SELECT map_top_n(map(ARRAY['a', 'b', 'c'], ARRAY[2, 3, 1]), 2); -- {'b' -> 3, 'a' -> 2}
SELECT map_top_n(map(ARRAY['a', 'b', 'c'], ARRAY[NULL, 3, NULL]), 2); -- {'b' -> 3, 'a' -> NULL}


.. function:: multimap_from_entries(array(row(K,V))) -> map(K,array(V))

Returns a multimap created from the given array of entries. Each key can be associated with multiple values. ::
Expand Down
92 changes: 92 additions & 0 deletions velox/functions/prestosql/MapTopN.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <queue>
#include <vector>

#include "velox/expression/ComplexViewTypes.h"
#include "velox/functions/Udf.h"

namespace facebook::velox::functions {

/// Simple-function implementation of map_top_n(map(K,V), n) -> map(K,V).
///
/// Copies into the result the 'n' entries of the input map that have the
/// largest values. Entries whose value is null rank below every entry with a
/// non-null value. Keys are never compared, so when several entries tie on
/// value (including several null values) which of them is retained depends on
/// the order in which the input map view yields its entries.
template <typename TExec>
struct MapTopNFunction {
  VELOX_DEFINE_FUNCTION_TYPES(TExec);

  /// Orders map-entry iterators by value, "greater first": returns true when
  /// the value under 'l' ranks strictly above the value under 'r'.
  ///
  /// Used as the std::priority_queue comparator below, which makes the queue
  /// a min-heap on value: top() is the lowest-ranked entry currently kept.
  template <typename It>
  struct Compare {
    bool operator()(const It& l, const It& r) const {
      static const CompareFlags flags{
          false /*nullsFirst*/,
          true /*ascending*/,
          false /*equalsOnly*/,
          CompareFlags::NullHandlingMode::kNullAsIndeterminate};

      if (l->second.has_value() && r->second.has_value()) {
        return l->second.value().compare(r->second.value(), flags) > 0;
      }

      // At least one side is null. A non-null value outranks a null one; two
      // nulls are equivalent (neither outranks the other).
      return l->second.has_value();
    }
  };

  /// @param out Writer for the result map.
  /// @param inputMap Map to truncate.
  /// @param n Maximum number of entries to keep. A negative value raises a
  /// user error with the message "n must be greater than or equal to 0".
  void call(
      out_type<Map<Generic<T1>, Generic<T2>>>& out,
      const arg_type<Map<Generic<T1>, Generic<T2>>>& inputMap,
      int64_t n) {
    VELOX_USER_CHECK_GE(n, 0, "n must be greater than or equal to 0");

    // Nothing is written to 'out', i.e. the result is an empty map.
    if (n == 0) {
      return;
    }

    // Every entry makes the cut; no ranking is needed.
    if (n >= inputMap.size()) {
      out.copy_from(inputMap);
      return;
    }

    using It = typename arg_type<Map<Generic<T1>, Generic<T2>>>::Iterator;

    // 'n' is strictly positive and smaller than the map size at this point,
    // so the conversion is lossless. Converting once avoids a signed/unsigned
    // comparison against the heap size on every iteration.
    const auto limit = static_cast<size_t>(n);

    Compare<It> comparator;

    // Min-heap holding at most 'limit' entries: the best seen so far.
    std::priority_queue<It, std::vector<It>, Compare<It>> topEntries(
        comparator);

    for (auto it = inputMap.begin(); it != inputMap.end(); ++it) {
      if (topEntries.size() < limit) {
        topEntries.push(it);
      } else if (comparator(it, topEntries.top())) {
        // 'it' strictly outranks the weakest kept entry: swap it in. Entries
        // that merely tie with the weakest kept entry are dropped.
        topEntries.pop();
        topEntries.push(it);
      }
    }

    // Drain the heap into the result, lowest-ranked entry first.
    while (!topEntries.empty()) {
      auto it = topEntries.top();

      if (!it->second.has_value()) {
        // Null value: only the key needs to be written.
        auto& keyWriter = out.add_null();
        keyWriter.copy_from(it->first);
      } else {
        auto [keyWriter, valueWriter] = out.add_item();
        keyWriter.copy_from(it->first);
        valueWriter.copy_from(it->second.value());
      }

      topEntries.pop();
    }
  }
};

} // namespace facebook::velox::functions
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "velox/expression/VectorFunction.h"
#include "velox/functions/Registerer.h"
#include "velox/functions/lib/MapConcat.h"
#include "velox/functions/prestosql/MapTopN.h"
#include "velox/functions/prestosql/MultimapFromEntries.h"

namespace facebook::velox::functions {
Expand Down Expand Up @@ -48,6 +49,12 @@ void registerMapFunctions(const std::string& prefix) {
MultimapFromEntriesFunction,
Map<Generic<T1>, Array<Generic<T2>>>,
Array<Row<Generic<T1>, Generic<T2>>>>({prefix + "multimap_from_entries"});

registerFunction<
MapTopNFunction,
Map<Generic<T1>, Generic<T2>>,
Map<Generic<T1>, Generic<T2>>,
int64_t>({prefix + "map_top_n"});
}

void registerMapAllowingDuplicates(
Expand Down
1 change: 1 addition & 0 deletions velox/functions/prestosql/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ add_executable(
MapEntriesTest.cpp
MapFilterTest.cpp
MapFromEntriesTest.cpp
MapTopNTest.cpp
MultimapFromEntriesTest.cpp
MapKeysAndValuesTest.cpp
MapMatchTest.cpp
Expand Down
70 changes: 70 additions & 0 deletions velox/functions/prestosql/tests/MapTopNTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/common/base/tests/GTestUtils.h"
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"

using namespace facebook::velox::test;

namespace facebook::velox::functions {
namespace {

/// Fixture for tests of the map_top_n Presto function.
class MapTopNTest : public test::FunctionBaseTest {};

TEST_F(MapTopNTest, basic) {
  // Input column c0 of type map(integer, bigint). Rows cover: all non-null
  // values, a mix of null and non-null values, maps with fewer than three
  // entries, an empty map and a map whose values are all null.
  const auto inputMaps = makeMapVectorFromJson<int32_t, int64_t>({
      "{1:3, 2:5, 3:1, 4:4, 5:2}",
      "{1:3, 2:5, 3:null, 4:4, 5:2}",
      "{1:null, 2:null, 3:1, 4:4, 5:null}",
      "{1:10, 2:7, 3:11, 5:4}",
      "{1:10, 2:7, 3:0}",
      "{1:null, 2:10}",
      "{}",
      "{1:null, 2:null, 3:null, 4:null}",
  });
  const auto data = makeRowVector({inputMaps});

  // n = 3. Expect the three entries with the largest values from each map.
  {
    // NOTE(review): the third and last rows contain several tied null values
    // and expect specific keys to survive; presumably this depends on the
    // entry order produced by makeMapVectorFromJson - confirm.
    const auto expectedTopThree = makeMapVectorFromJson<int32_t, int64_t>({
        "{2:5, 4:4, 1:3}",
        "{2:5, 4:4, 1:3}",
        "{4:4, 3:1, 5:null}",
        "{3:11, 1:10, 2:7}",
        "{1:10, 2:7, 3:0}",
        "{2:10, 1:null}",
        "{}",
        "{2:null, 3:null, 4:null}",
    });

    const auto actual = evaluate("map_top_n(c0, 3)", data);
    assertEqualVectors(expectedTopThree, actual);
  }

  // n = 0. Expect empty maps.
  {
    const auto expectedEmpty = makeMapVectorFromJson<int32_t, int64_t>(
        {"{}", "{}", "{}", "{}", "{}", "{}", "{}", "{}"});

    const auto actual = evaluate("map_top_n(c0, 0)", data);
    assertEqualVectors(expectedEmpty, actual);
  }

  // n is negative. Expect an error.
  VELOX_ASSERT_THROW(
      evaluate("map_top_n(c0, -1)", data),
      "n must be greater than or equal to 0");
}

} // namespace
} // namespace facebook::velox::functions

0 comments on commit 78859a1

Please sign in to comment.