Skip to content

Commit

Permalink
Merge pull request duckdb#9653 from carlopi/serialize_quantiles
Browse files Browse the repository at this point in the history
Serialize decimal quantiles
  • Loading branch information
Mytherin authored Nov 13, 2023
2 parents 702d22f + a1ef8e7 commit fb59632
Show file tree
Hide file tree
Showing 6 changed files with 204 additions and 26 deletions.
39 changes: 39 additions & 0 deletions src/common/enum_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
#include "duckdb/common/types/timestamp.hpp"
#include "duckdb/common/types/vector.hpp"
#include "duckdb/common/types/vector_buffer.hpp"
#include "duckdb/core_functions/aggregate/quantile_enum.hpp"
#include "duckdb/execution/index/art/art.hpp"
#include "duckdb/execution/index/art/node.hpp"
#include "duckdb/execution/operator/scan/csv/base_csv_reader.hpp"
Expand Down Expand Up @@ -4571,6 +4572,44 @@ ProfilerPrintFormat EnumUtil::FromString<ProfilerPrintFormat>(const char *value)
throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value));
}

template<>
const char* EnumUtil::ToChars<QuantileSerializationType>(QuantileSerializationType value) {
switch(value) {
case QuantileSerializationType::NON_DECIMAL:
return "NON_DECIMAL";
case QuantileSerializationType::DECIMAL_DISCRETE:
return "DECIMAL_DISCRETE";
case QuantileSerializationType::DECIMAL_DISCRETE_LIST:
return "DECIMAL_DISCRETE_LIST";
case QuantileSerializationType::DECIMAL_CONTINUOUS:
return "DECIMAL_CONTINUOUS";
case QuantileSerializationType::DECIMAL_CONTINUOUS_LIST:
return "DECIMAL_CONTINUOUS_LIST";
default:
throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value));
}
}

template<>
QuantileSerializationType EnumUtil::FromString<QuantileSerializationType>(const char *value) {
if (StringUtil::Equals(value, "NON_DECIMAL")) {
return QuantileSerializationType::NON_DECIMAL;
}
if (StringUtil::Equals(value, "DECIMAL_DISCRETE")) {
return QuantileSerializationType::DECIMAL_DISCRETE;
}
if (StringUtil::Equals(value, "DECIMAL_DISCRETE_LIST")) {
return QuantileSerializationType::DECIMAL_DISCRETE_LIST;
}
if (StringUtil::Equals(value, "DECIMAL_CONTINUOUS")) {
return QuantileSerializationType::DECIMAL_CONTINUOUS;
}
if (StringUtil::Equals(value, "DECIMAL_CONTINUOUS_LIST")) {
return QuantileSerializationType::DECIMAL_CONTINUOUS_LIST;
}
throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value));
}

template<>
const char* EnumUtil::ToChars<QueryNodeType>(QueryNodeType value) {
switch(value) {
Expand Down
109 changes: 84 additions & 25 deletions src/core_functions/aggregate/holistic/quantile.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "duckdb/execution/expression_executor.hpp"
#include "duckdb/core_functions/aggregate/holistic_functions.hpp"
#include "duckdb/core_functions/aggregate/quantile_enum.hpp"
#include "duckdb/planner/expression.hpp"
#include "duckdb/common/operator/cast_operators.hpp"
#include "duckdb/common/operator/abs.hpp"
Expand Down Expand Up @@ -442,6 +443,8 @@ inline Value QuantileAbs(const Value &v) {
}
}

void BindQuantileInner(AggregateFunction &function, const LogicalType &type, QuantileSerializationType quantile_type);

struct QuantileBindData : public FunctionData {
QuantileBindData() {
}
Expand Down Expand Up @@ -507,15 +510,59 @@ struct QuantileBindData : public FunctionData {
deserializer.ReadProperty(100, "quantiles", raw);
deserializer.ReadProperty(101, "order", result->order);
deserializer.ReadProperty(102, "desc", result->desc);
QuantileSerializationType deserialization_type;
deserializer.ReadPropertyWithDefault(103, "quantile_type", deserialization_type,
QuantileSerializationType::NON_DECIMAL);

if (deserialization_type != QuantileSerializationType::NON_DECIMAL) {
LogicalType arg_type;
deserializer.ReadProperty(104, "logical_type", arg_type);

BindQuantileInner(function, arg_type, deserialization_type);
}

for (const auto &r : raw) {
result->quantiles.emplace_back(QuantileValue(r));
}
return std::move(result);
}

static void SerializeDecimal(Serializer &serializer, const optional_ptr<FunctionData> bind_data_p,
const AggregateFunction &function) {
throw NotImplementedException("FIXME: serializing quantiles with decimals is not supported right now");
static void SerializeDecimalDiscrete(Serializer &serializer, const optional_ptr<FunctionData> bind_data_p,
const AggregateFunction &function) {
Serialize(serializer, bind_data_p, function);

serializer.WritePropertyWithDefault<QuantileSerializationType>(
103, "quantile_type", QuantileSerializationType::DECIMAL_DISCRETE, QuantileSerializationType::NON_DECIMAL);
serializer.WriteProperty(104, "logical_type", function.arguments[0]);
}
static void SerializeDecimalDiscreteList(Serializer &serializer, const optional_ptr<FunctionData> bind_data_p,
const AggregateFunction &function) {

Serialize(serializer, bind_data_p, function);

serializer.WritePropertyWithDefault<QuantileSerializationType>(103, "quantile_type",
QuantileSerializationType::DECIMAL_DISCRETE_LIST,
QuantileSerializationType::NON_DECIMAL);
serializer.WriteProperty(104, "logical_type", function.arguments[0]);
}
static void SerializeDecimalContinuous(Serializer &serializer, const optional_ptr<FunctionData> bind_data_p,
const AggregateFunction &function) {
Serialize(serializer, bind_data_p, function);

serializer.WritePropertyWithDefault<QuantileSerializationType>(103, "quantile_type",
QuantileSerializationType::DECIMAL_CONTINUOUS,
QuantileSerializationType::NON_DECIMAL);
serializer.WriteProperty(104, "logical_type", function.arguments[0]);
}
static void SerializeDecimalContinuousList(Serializer &serializer, const optional_ptr<FunctionData> bind_data_p,
const AggregateFunction &function) {

Serialize(serializer, bind_data_p, function);

serializer.WritePropertyWithDefault<QuantileSerializationType>(
103, "quantile_type", QuantileSerializationType::DECIMAL_CONTINUOUS_LIST,
QuantileSerializationType::NON_DECIMAL);
serializer.WriteProperty(104, "logical_type", function.arguments[0]);
}

vector<QuantileValue> quantiles;
Expand Down Expand Up @@ -1232,7 +1279,7 @@ unique_ptr<FunctionData> BindMedianDecimal(ClientContext &context, AggregateFunc

function = GetDiscreteQuantileAggregateFunction(arguments[0]->return_type);
function.name = "median";
function.serialize = QuantileBindData::SerializeDecimal;
function.serialize = QuantileBindData::SerializeDecimalDiscrete;
function.deserialize = QuantileBindData::Deserialize;
function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
return bind_data;
Expand Down Expand Up @@ -1283,50 +1330,62 @@ unique_ptr<FunctionData> BindQuantile(ClientContext &context, AggregateFunction
return make_uniq<QuantileBindData>(quantiles);
}

void BindQuantileInner(AggregateFunction &function, const LogicalType &type, QuantileSerializationType quantile_type) {
switch (quantile_type) {
case QuantileSerializationType::DECIMAL_DISCRETE:
function = GetDiscreteQuantileAggregateFunction(type);
function.serialize = QuantileBindData::SerializeDecimalDiscrete;
function.name = "quantile_disc";
break;
case QuantileSerializationType::DECIMAL_DISCRETE_LIST:
function = GetDiscreteQuantileListAggregateFunction(type);
function.serialize = QuantileBindData::SerializeDecimalDiscreteList;
function.name = "quantile_disc";
break;
case QuantileSerializationType::DECIMAL_CONTINUOUS:
function = GetContinuousQuantileAggregateFunction(type);
function.serialize = QuantileBindData::SerializeDecimalContinuous;
function.name = "quantile_cont";
break;
case QuantileSerializationType::DECIMAL_CONTINUOUS_LIST:
function = GetContinuousQuantileListAggregateFunction(type);
function.serialize = QuantileBindData::SerializeDecimalContinuousList;
function.name = "quantile_cont";
break;
case QuantileSerializationType::NON_DECIMAL:
throw SerializationException("NON_DECIMAL is not a valid quantile_type for BindQuantileInner");
}
function.deserialize = QuantileBindData::Deserialize;
function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
}

unique_ptr<FunctionData> BindDiscreteQuantileDecimal(ClientContext &context, AggregateFunction &function,
vector<unique_ptr<Expression>> &arguments) {
auto bind_data = BindQuantile(context, function, arguments);
function = GetDiscreteQuantileAggregateFunction(arguments[0]->return_type);
function.name = "quantile_disc";
function.serialize = QuantileBindData::SerializeDecimal;
function.deserialize = QuantileBindData::Deserialize;
function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
BindQuantileInner(function, arguments[0]->return_type, QuantileSerializationType::DECIMAL_DISCRETE);
return bind_data;
}

unique_ptr<FunctionData> BindDiscreteQuantileDecimalList(ClientContext &context, AggregateFunction &function,
vector<unique_ptr<Expression>> &arguments) {
auto bind_data = BindQuantile(context, function, arguments);
function = GetDiscreteQuantileListAggregateFunction(arguments[0]->return_type);
function.name = "quantile_disc";
function.serialize = QuantileBindData::SerializeDecimal;
function.deserialize = QuantileBindData::Deserialize;
function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
BindQuantileInner(function, arguments[0]->return_type, QuantileSerializationType::DECIMAL_DISCRETE_LIST);
return bind_data;
}

unique_ptr<FunctionData> BindContinuousQuantileDecimal(ClientContext &context, AggregateFunction &function,
vector<unique_ptr<Expression>> &arguments) {
auto bind_data = BindQuantile(context, function, arguments);
function = GetContinuousQuantileAggregateFunction(arguments[0]->return_type);
function.name = "quantile_cont";
function.serialize = QuantileBindData::SerializeDecimal;
function.deserialize = QuantileBindData::Deserialize;
function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
BindQuantileInner(function, arguments[0]->return_type, QuantileSerializationType::DECIMAL_CONTINUOUS);
return bind_data;
}

unique_ptr<FunctionData> BindContinuousQuantileDecimalList(ClientContext &context, AggregateFunction &function,
vector<unique_ptr<Expression>> &arguments) {
auto bind_data = BindQuantile(context, function, arguments);
function = GetContinuousQuantileListAggregateFunction(arguments[0]->return_type);
function.name = "quantile_cont";
function.serialize = QuantileBindData::SerializeDecimal;
function.deserialize = QuantileBindData::Deserialize;
function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
BindQuantileInner(function, arguments[0]->return_type, QuantileSerializationType::DECIMAL_CONTINUOUS_LIST);
return bind_data;
}

static bool CanInterpolate(const LogicalType &type) {
switch (type.id()) {
case LogicalTypeId::INTERVAL:
Expand Down
2 changes: 1 addition & 1 deletion src/function/function_set.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ AggregateFunction AggregateFunctionSet::GetFunctionByArguments(ClientContext &co
}
bool is_prefix = true;
for (idx_t k = 0; k < arguments.size(); k++) {
if (arguments[k] != func.arguments[k]) {
if (arguments[k].id() != func.arguments[k].id()) {
is_prefix = false;
break;
}
Expand Down
8 changes: 8 additions & 0 deletions src/include/duckdb/common/enum_util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,8 @@ enum class PreparedParamType : uint8_t;

enum class ProfilerPrintFormat : uint8_t;

enum class QuantileSerializationType : uint8_t;

enum class QueryNodeType : uint8_t;

enum class QueryResultType : uint8_t;
Expand Down Expand Up @@ -557,6 +559,9 @@ const char* EnumUtil::ToChars<PreparedParamType>(PreparedParamType value);
template<>
const char* EnumUtil::ToChars<ProfilerPrintFormat>(ProfilerPrintFormat value);

template<>
const char* EnumUtil::ToChars<QuantileSerializationType>(QuantileSerializationType value);

template<>
const char* EnumUtil::ToChars<QueryNodeType>(QueryNodeType value);

Expand Down Expand Up @@ -948,6 +953,9 @@ PreparedParamType EnumUtil::FromString<PreparedParamType>(const char *value);
template<>
ProfilerPrintFormat EnumUtil::FromString<ProfilerPrintFormat>(const char *value);

template<>
QuantileSerializationType EnumUtil::FromString<QuantileSerializationType>(const char *value);

template<>
QueryNodeType EnumUtil::FromString<QueryNodeType>(const char *value);

Expand Down
21 changes: 21 additions & 0 deletions src/include/duckdb/core_functions/aggregate/quantile_enum.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//===----------------------------------------------------------------------===//
// DuckDB
//
// duckdb/core_functions/aggregate/quantile_enum.hpp
//
//
//===----------------------------------------------------------------------===//

#pragma once

namespace duckdb {

enum class QuantileSerializationType : uint8_t {
NON_DECIMAL = 0,
DECIMAL_DISCRETE,
DECIMAL_DISCRETE_LIST,
DECIMAL_CONTINUOUS,
DECIMAL_CONTINUOUS_LIST
};

}
51 changes: 51 additions & 0 deletions test/sql/aggregate/quantile_fun.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# name: test/sql/aggregate/quantile_fun.test
# description: Test pg_sequence function
# group: [aggregate]

require tpch

# scalar quantiles
statement ok
create table quantiles as select range r, random() FROM range(100) union all values (NULL, 0.1), (NULL, 0.5), (NULL, 0.9) order by 2;

statement ok
CALL dbgen(sf=0.001);

statement ok
PRAGMA enable_verification;

statement ok
PRAGMA verify_external;

statement ok
SELECT quantile_disc(0.1::decimal(4,1), [0.1, 0.5, 0.9]);

statement ok
SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY "l_extendedprice") FROM lineitem;

statement ok
SET default_null_order='nulls_first';

foreach type decimal(4,1) decimal(8,1) decimal(12,1) decimal(18,1) decimal(24,1)

query I
SELECT quantile_disc(r::${type}, 0.1) FROM quantiles
----
9.0

query I
SELECT quantile_disc(r::${type}, [0.1, 0.5, 0.9]) FROM quantiles
----
[9.0, 49.0, 89.0]

query I
SELECT quantile_cont(r::${type}, 0.15) FROM quantiles
----
14.8

query I
SELECT quantile_cont(r::${type}, [0.15, 0.5, 0.9]) FROM quantiles
----
[14.8, 49.5, 89.1]

endloop

0 comments on commit fb59632

Please sign in to comment.