Skip to content

Commit

Permalink
Add map_normalize Presto function (facebookincubator#9086)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: facebookincubator#9086

Differential Revision: D54918766
  • Loading branch information
mbasmanova authored and facebook-github-bot committed Mar 14, 2024
1 parent 704113a commit b10046f
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 0 deletions.
14 changes: 14 additions & 0 deletions velox/docs/functions/presto/map.rst
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,20 @@ Map Functions

SELECT map_from_entries(ARRAY[(1, 'x'), (2, 'y')]); -- {1 -> 'x', 2 -> 'y'}

.. function:: map_normalize(map(varchar,double)) -> map(varchar,double)

Returns the map with the same keys but all non-null values scaled proportionally
so that the sum of values becomes 1. Map entries with null values remain unchanged.

When total sum of non-null values is zero, null values remain null,
zero, NaN, Infinity and -Infinity values become NaN,
positive values become Infinity, negative values become -Infinity.::

SELECT map_normalize(map(array['a', 'b', 'c'], array[1, 4, 5])); -- {a=0.1, b=0.4, c=0.5}
SELECT map_normalize(map(array['a', 'b', 'c', 'd'], array[1, null, 4, 5])); -- {a=0.1, b=null, c=0.4, d=0.5}
SELECT map_normalize(map(array['a', 'b', 'c'], array[1, 0, -1])); -- {a=Infinity, b=NaN, c=-Infinity}


.. function:: map_subset(map(K,V), array(k)) -> map(K,V)

Constructs a map from those entries of ``map`` for which the key is in the array given::
Expand Down
54 changes: 54 additions & 0 deletions velox/functions/prestosql/MapNormalize.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "velox/expression/ComplexViewTypes.h"
#include "velox/functions/Udf.h"

namespace facebook::velox::functions {

template <typename TExec>
struct MapNormalizeFunction {
VELOX_DEFINE_FUNCTION_TYPES(TExec);

void call(
out_type<Map<Varchar, double>>& out,
const arg_type<Map<Varchar, double>>& inputMap) {
double totalSum = 0.0;
for (const auto& entry : inputMap) {
if (entry.second.has_value()) {
totalSum += entry.second.value();
}
}

// totalSum can be zero, but that's OK. See
// https://github.com/prestodb/presto/issues/22209 for Presto Java
// semantics.

for (const auto& entry : inputMap) {
if (!entry.second.has_value()) {
auto& keyWriter = out.add_null();
keyWriter.copy_from(entry.first);
} else {
auto [keyWriter, valueWriter] = out.add_item();
keyWriter.copy_from(entry.first);
valueWriter = entry.second.value() / totalSum;
}
}
}
};

} // namespace facebook::velox::functions
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "velox/expression/VectorFunction.h"
#include "velox/functions/Registerer.h"
#include "velox/functions/lib/MapConcat.h"
#include "velox/functions/prestosql/MapNormalize.h"
#include "velox/functions/prestosql/MapSubset.h"
#include "velox/functions/prestosql/MapTopN.h"
#include "velox/functions/prestosql/MultimapFromEntries.h"
Expand Down Expand Up @@ -77,6 +78,11 @@ void registerMapFunctions(const std::string& prefix) {
Map<Generic<T1>, Generic<T2>>,
Map<Generic<T1>, Generic<T2>>,
Array<Generic<T1>>>({prefix + "map_subset"});

registerFunction<
MapNormalizeFunction,
Map<Varchar, double>,
Map<Varchar, double>>({prefix + "map_normalize"});
}

void registerMapAllowingDuplicates(
Expand Down
1 change: 1 addition & 0 deletions velox/functions/prestosql/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ add_executable(
MapEntriesTest.cpp
MapFilterTest.cpp
MapFromEntriesTest.cpp
MapNormalizeTest.cpp
MapTopNTest.cpp
MultimapFromEntriesTest.cpp
MapKeysAndValuesTest.cpp
Expand Down
51 changes: 51 additions & 0 deletions velox/functions/prestosql/tests/MapNormalizeTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/common/base/tests/GTestUtils.h"
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"

using namespace facebook::velox::test;

namespace facebook::velox::functions {
namespace {

class MapNormalizeTest : public test::FunctionBaseTest {};

TEST_F(MapNormalizeTest, basic) {
auto data = makeRowVector({
makeMapVectorFromJson<std::string, double>({
"{\"a\": 1.0, \"b\": 2.0, \"c\": 2.0}",
"{\"a\": 2.0, \"b\": 2.0, \"c\": null, \"d\": 1.0}",
"{\"a\": 1.0, \"b\": -1.0, \"c\": null, \"d\": 0.0}",
"{\"a\": null, \"b\": null}",
"{}",
}),
});

auto result = evaluate("map_normalize(c0)", data);

auto expected = makeMapVectorFromJson<std::string, double>({
"{\"a\": 0.2, \"b\": 0.4, \"c\": 0.4}",
"{\"a\": 0.4, \"b\": 0.4, \"c\": null, \"d\": 0.2}",
"{\"a\": Infinity, \"b\": -Infinity, \"c\": null, \"d\": NaN}",
"{\"a\": null, \"b\": null}",
"{}",
});

assertEqualVectors(expected, result);
}

} // namespace
} // namespace facebook::velox::functions

0 comments on commit b10046f

Please sign in to comment.