diff --git a/velox/docs/functions/presto/map.rst b/velox/docs/functions/presto/map.rst
index 3658b55faeb6..b91addaa5b5b 100644
--- a/velox/docs/functions/presto/map.rst
+++ b/velox/docs/functions/presto/map.rst
@@ -70,6 +70,20 @@ Map Functions
 
         SELECT map_from_entries(ARRAY[(1, 'x'), (2, 'y')]); -- {1 -> 'x', 2 -> 'y'}
 
+.. function:: map_normalize(map(varchar,double)) -> map(varchar,double)
+
+    Returns the map with the same keys but all non-null values scaled proportionally
+    so that the sum of values becomes 1. Map entries with null values remain unchanged.
+
+    When the total sum of non-null values is zero, null values remain null,
+    zero, NaN, Infinity and -Infinity values become NaN,
+    positive values become Infinity, negative values become -Infinity.::
+
+        SELECT map_normalize(map(array['a', 'b', 'c'], array[1, 4, 5])); -- {a=0.1, b=0.4, c=0.5}
+        SELECT map_normalize(map(array['a', 'b', 'c', 'd'], array[1, null, 4, 5])); -- {a=0.1, b=null, c=0.4, d=0.5}
+        SELECT map_normalize(map(array['a', 'b', 'c'], array[1, 0, -1])); -- {a=Infinity, b=NaN, c=-Infinity}
+
+
 .. function:: map_subset(map(K,V), array(k)) -> map(K,V)
 
     Constructs a map from those entries of ``map`` for which the key is in the array given::
diff --git a/velox/functions/prestosql/MapNormalize.h b/velox/functions/prestosql/MapNormalize.h
new file mode 100644
index 000000000000..edf7cede724f
--- /dev/null
+++ b/velox/functions/prestosql/MapNormalize.h
@@ -0,0 +1,63 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include "velox/expression/ComplexViewTypes.h"
+#include "velox/functions/Udf.h"
+
+namespace facebook::velox::functions {
+
+/// Simple-function implementation registered as map_normalize for
+/// map(varchar, double). Divides every non-null value by the sum of all
+/// non-null values in the same map; entries with null values are emitted
+/// with their key and a null value.
+template <typename TExec>
+struct MapNormalizeFunction {
+  VELOX_DEFINE_FUNCTION_TYPES(TExec);
+
+  /// @param out Writer for the result map; receives one entry per input
+  /// entry, in the iteration order of 'inputMap'.
+  /// @param inputMap Map whose values are normalized; values may be null.
+  void call(
+      out_type<Map<Varchar, double>>& out,
+      const arg_type<Map<Varchar, double>>& inputMap) {
+    // First pass: add up the non-null values.
+    double totalSum = 0.0;
+    for (const auto& entry : inputMap) {
+      if (entry.second.has_value()) {
+        totalSum += entry.second.value();
+      }
+    }
+
+    // totalSum can be zero, but that's OK. See
+    // https://github.com/prestodb/presto/issues/22209 for Presto Java
+    // semantics.
+
+    // Second pass: copy each key and write either null or value / totalSum.
+    for (const auto& entry : inputMap) {
+      if (!entry.second.has_value()) {
+        auto& keyWriter = out.add_null();
+        keyWriter.copy_from(entry.first);
+      } else {
+        auto [keyWriter, valueWriter] = out.add_item();
+        keyWriter.copy_from(entry.first);
+        valueWriter = entry.second.value() / totalSum;
+      }
+    }
+  }
+};
+
+} // namespace facebook::velox::functions
diff --git a/velox/functions/prestosql/registration/MapFunctionsRegistration.cpp b/velox/functions/prestosql/registration/MapFunctionsRegistration.cpp
index 0b73113ffd9a..fca2990a4aa7 100644
--- a/velox/functions/prestosql/registration/MapFunctionsRegistration.cpp
+++ b/velox/functions/prestosql/registration/MapFunctionsRegistration.cpp
@@ -17,6 +17,7 @@
 #include "velox/expression/VectorFunction.h"
 #include "velox/functions/Registerer.h"
 #include "velox/functions/lib/MapConcat.h"
+#include "velox/functions/prestosql/MapNormalize.h"
 #include "velox/functions/prestosql/MapSubset.h"
 #include "velox/functions/prestosql/MapTopN.h"
 #include "velox/functions/prestosql/MultimapFromEntries.h"
@@ -77,6 +78,11 @@ void registerMapFunctions(const std::string& prefix) {
       Map<Generic<T1>, Generic<T2>>,
       Map<Generic<T1>, Generic<T2>>,
       Array<Generic<T1>>>({prefix + "map_subset"});
+
+  registerFunction<
+      MapNormalizeFunction,
+      Map<Varchar, double>,
+      Map<Varchar, double>>({prefix + "map_normalize"});
 }
 
 void registerMapAllowingDuplicates(
diff --git 
a/velox/functions/prestosql/tests/CMakeLists.txt b/velox/functions/prestosql/tests/CMakeLists.txt
index c0df56c5ae1b..d0687a849ea0 100644
--- a/velox/functions/prestosql/tests/CMakeLists.txt
+++ b/velox/functions/prestosql/tests/CMakeLists.txt
@@ -66,6 +66,7 @@ add_executable(
   MapEntriesTest.cpp
   MapFilterTest.cpp
   MapFromEntriesTest.cpp
+  MapNormalizeTest.cpp
   MapTopNTest.cpp
   MultimapFromEntriesTest.cpp
   MapKeysAndValuesTest.cpp
diff --git a/velox/functions/prestosql/tests/MapNormalizeTest.cpp b/velox/functions/prestosql/tests/MapNormalizeTest.cpp
new file mode 100644
index 000000000000..db3eb3902935
--- /dev/null
+++ b/velox/functions/prestosql/tests/MapNormalizeTest.cpp
@@ -0,0 +1,54 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "velox/common/base/tests/GTestUtils.h"
+#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"
+
+using namespace facebook::velox::test;
+
+namespace facebook::velox::functions {
+namespace {
+
+class MapNormalizeTest : public test::FunctionBaseTest {};
+
+// Covers, one input row each: a plain map, a map with a null value, a map
+// whose non-null values sum to zero, a map with only null values, and an
+// empty map.
+TEST_F(MapNormalizeTest, basic) {
+  auto data = makeRowVector({
+      makeMapVectorFromJson<std::string, double>({
+          "{\"a\": 1.0, \"b\": 2.0, \"c\": 2.0}",
+          "{\"a\": 2.0, \"b\": 2.0, \"c\": null, \"d\": 1.0}",
+          "{\"a\": 1.0, \"b\": -1.0, \"c\": null, \"d\": 0.0}",
+          "{\"a\": null, \"b\": null}",
+          "{}",
+      }),
+  });
+
+  auto result = evaluate("map_normalize(c0)", data);
+
+  auto expected = makeMapVectorFromJson<std::string, double>({
+      "{\"a\": 0.2, \"b\": 0.4, \"c\": 0.4}",
+      "{\"a\": 0.4, \"b\": 0.4, \"c\": null, \"d\": 0.2}",
+      "{\"a\": Infinity, \"b\": -Infinity, \"c\": null, \"d\": NaN}",
+      "{\"a\": null, \"b\": null}",
+      "{}",
+  });
+
+  assertEqualVectors(expected, result);
+}
+
+} // namespace
+} // namespace facebook::velox::functions