Skip to content

Commit

Permalink
Refine function test framework (#4861)
Browse files Browse the repository at this point in the history
close #4830
  • Loading branch information
windtalker authored May 26, 2022
1 parent b34cd1d commit 973de13
Show file tree
Hide file tree
Showing 12 changed files with 521 additions and 160 deletions.
190 changes: 105 additions & 85 deletions dbms/src/Debug/astToExecutor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,93 @@

namespace DB
{
void literalFieldToTiPBExpr(const ColumnInfo & ci, const Field & val_field, tipb::Expr * expr, Int32 collator_id)
{
*(expr->mutable_field_type()) = columnInfoToFieldType(ci);
expr->mutable_field_type()->set_collate(collator_id);
if (!val_field.isNull())
{
WriteBufferFromOwnString ss;
switch (ci.tp)
{
case TiDB::TypeLongLong:
case TiDB::TypeLong:
case TiDB::TypeShort:
case TiDB::TypeTiny:
case TiDB::TypeInt24:
if (ci.hasUnsignedFlag())
{
expr->set_tp(tipb::ExprType::Uint64);
UInt64 val = val_field.safeGet<UInt64>();
encodeDAGUInt64(val, ss);
}
else
{
expr->set_tp(tipb::ExprType::Int64);
Int64 val = val_field.safeGet<Int64>();
encodeDAGInt64(val, ss);
}
break;
case TiDB::TypeFloat:
{
expr->set_tp(tipb::ExprType::Float32);
auto val = static_cast<Float32>(val_field.safeGet<Float64>());
encodeDAGFloat32(val, ss);
break;
}
case TiDB::TypeDouble:
{
expr->set_tp(tipb::ExprType::Float64);
Float64 val = val_field.safeGet<Float64>();
encodeDAGFloat64(val, ss);
break;
}
case TiDB::TypeString:
{
expr->set_tp(tipb::ExprType::String);
const auto & val = val_field.safeGet<String>();
encodeDAGString(val, ss);
break;
}
case TiDB::TypeNewDecimal:
{
expr->set_tp(tipb::ExprType::MysqlDecimal);
encodeDAGDecimal(val_field, ss);
break;
}
case TiDB::TypeDate:
{
expr->set_tp(tipb::ExprType::MysqlTime);
UInt64 val = val_field.safeGet<UInt64>();
encodeDAGUInt64(MyDate(val).toPackedUInt(), ss);
break;
}
case TiDB::TypeDatetime:
case TiDB::TypeTimestamp:
{
expr->set_tp(tipb::ExprType::MysqlTime);
UInt64 val = val_field.safeGet<UInt64>();
encodeDAGUInt64(MyDateTime(val).toPackedUInt(), ss);
break;
}
case TiDB::TypeTime:
{
expr->set_tp(tipb::ExprType::MysqlDuration);
Int64 val = val_field.safeGet<Int64>();
encodeDAGInt64(val, ss);
break;
}
default:
throw Exception(fmt::format("Type {} does not support literal in function unit test", getDataTypeByColumnInfo(ci)->getName()));
}
expr->set_val(ss.releaseStr());
}
else
{
expr->set_tp(tipb::ExprType::Null);
}
}

namespace
{
std::unordered_map<String, tipb::ScalarFuncSig> func_name_to_sig({
Expand Down Expand Up @@ -112,76 +199,9 @@ DAGColumnInfo toNullableDAGColumnInfo(const DAGColumnInfo & input)

// Converts the literal `value` into `expr`, tagging its field type with `collator_id`.
// NOTE(review): this span looks like a rendered diff in which the removed switch-based body
// (from `WriteBufferFromOwnString ss;` through `expr->set_val(ss.releaseStr());`) and its
// three-statement replacement (FieldToDataType -> reverseGetColumnInfo ->
// literalFieldToTiPBExpr) appear back to back. Confirm which of the two is the live
// implementation before relying on either.
void literalToPB(tipb::Expr * expr, const Field & value, uint32_t collator_id)
{
WriteBufferFromOwnString ss;
switch (value.getType())
{
case Field::Types::Which::Null:
{
expr->set_tp(tipb::Null);
auto * ft = expr->mutable_field_type();
ft->set_tp(TiDB::TypeNull);
ft->set_collate(collator_id);
// Null literal expr doesn't need value.
break;
}
case Field::Types::Which::UInt64:
{
expr->set_tp(tipb::Uint64);
auto * ft = expr->mutable_field_type();
ft->set_tp(TiDB::TypeLongLong);
ft->set_flag(TiDB::ColumnFlagUnsigned | TiDB::ColumnFlagNotNull);
ft->set_collate(collator_id);
encodeDAGUInt64(value.get<UInt64>(), ss);
break;
}
case Field::Types::Which::Int64:
{
expr->set_tp(tipb::Int64);
auto * ft = expr->mutable_field_type();
ft->set_tp(TiDB::TypeLongLong);
ft->set_flag(TiDB::ColumnFlagNotNull);
ft->set_collate(collator_id);
encodeDAGInt64(value.get<Int64>(), ss);
break;
}
case Field::Types::Which::Float64:
{
expr->set_tp(tipb::Float64);
auto * ft = expr->mutable_field_type();
ft->set_tp(TiDB::TypeFloat);
ft->set_flag(TiDB::ColumnFlagNotNull);
ft->set_collate(collator_id);
encodeDAGFloat64(value.get<Float64>(), ss);
break;
}
case Field::Types::Which::Decimal32:
case Field::Types::Which::Decimal64:
case Field::Types::Which::Decimal128:
case Field::Types::Which::Decimal256:
{
expr->set_tp(tipb::MysqlDecimal);
auto * ft = expr->mutable_field_type();
ft->set_tp(TiDB::TypeNewDecimal);
ft->set_flag(TiDB::ColumnFlagNotNull);
ft->set_collate(collator_id);
encodeDAGDecimal(value, ss);
break;
}
case Field::Types::Which::String:
{
expr->set_tp(tipb::String);
auto * ft = expr->mutable_field_type();
ft->set_tp(TiDB::TypeString);
ft->set_flag(TiDB::ColumnFlagNotNull);
ft->set_collate(collator_id);
// TODO: Align with TiDB.
encodeDAGBytes(value.get<String>(), ss);
break;
}
default:
throw Exception(String("Unsupported literal type: ") + value.getTypeName(), ErrorCodes::LOGICAL_ERROR);
}
expr->set_val(ss.releaseStr());
DataTypePtr type = applyVisitor(FieldToDataType(), value);
ColumnInfo ci = reverseGetColumnInfo({"", type}, 0, Field(), true);
literalFieldToTiPBExpr(ci, value, expr, collator_id);
}

String getFunctionNameForConstantFolding(tipb::Expr * expr)
Expand Down Expand Up @@ -262,15 +282,15 @@ void identifierToPB(const DAGSchema & input, ASTIdentifier * id, tipb::Expr * ex

void astToPB(const DAGSchema & input, ASTPtr ast, tipb::Expr * expr, uint32_t collator_id, const Context & context)
{
if (ASTIdentifier * id = typeid_cast<ASTIdentifier *>(ast.get()))
if (auto * id = typeid_cast<ASTIdentifier *>(ast.get()))
{
identifierToPB(input, id, expr, collator_id);
}
else if (ASTFunction * func = typeid_cast<ASTFunction *>(ast.get()))
else if (auto * func = typeid_cast<ASTFunction *>(ast.get()))
{
functionToPB(input, func, expr, collator_id, context);
}
else if (ASTLiteral * lit = typeid_cast<ASTLiteral *>(ast.get()))
else if (auto * lit = typeid_cast<ASTLiteral *>(ast.get()))
{
literalToPB(expr, lit->value, collator_id);
}
Expand Down Expand Up @@ -505,7 +525,7 @@ void identifierToPB(const DAGSchema & input, ASTIdentifier * id, tipb::Expr * ex

void collectUsedColumnsFromExpr(const DAGSchema & input, ASTPtr ast, std::unordered_set<String> & used_columns)
{
if (ASTIdentifier * id = typeid_cast<ASTIdentifier *>(ast.get()))
if (auto * id = typeid_cast<ASTIdentifier *>(ast.get()))
{
auto column_name = splitQualifiedName(id->getColumnName());
if (!column_name.first.empty())
Expand All @@ -526,7 +546,7 @@ void collectUsedColumnsFromExpr(const DAGSchema & input, ASTPtr ast, std::unorde
}
}
}
else if (ASTFunction * func = typeid_cast<ASTFunction *>(ast.get()))
else if (auto * func = typeid_cast<ASTFunction *>(ast.get()))
{
if (AggregateFunctionFactory::instance().isAggregateFunctionName(func->name))
{
Expand Down Expand Up @@ -559,7 +579,7 @@ void collectUsedColumnsFromExpr(const DAGSchema & input, ASTPtr ast, std::unorde
TiDB::ColumnInfo compileExpr(const DAGSchema & input, ASTPtr ast)
{
TiDB::ColumnInfo ci;
if (ASTIdentifier * id = typeid_cast<ASTIdentifier *>(ast.get()))
if (auto * id = typeid_cast<ASTIdentifier *>(ast.get()))
{
/// check column
auto ft = std::find_if(input.begin(), input.end(), [&](const auto & field) {
Expand All @@ -574,7 +594,7 @@ TiDB::ColumnInfo compileExpr(const DAGSchema & input, ASTPtr ast)
throw Exception("No such column " + id->getColumnName(), ErrorCodes::NO_SUCH_COLUMN_IN_TABLE);
ci = ft->second;
}
else if (ASTFunction * func = typeid_cast<ASTFunction *>(ast.get()))
else if (auto * func = typeid_cast<ASTFunction *>(ast.get()))
{
/// check function
String func_name_lowercase = Poco::toLower(func->name);
Expand Down Expand Up @@ -692,7 +712,7 @@ TiDB::ColumnInfo compileExpr(const DAGSchema & input, ASTPtr ast)
compileExpr(input, child_ast);
}
}
else if (ASTLiteral * lit = typeid_cast<ASTLiteral *>(ast.get()))
else if (auto * lit = typeid_cast<ASTLiteral *>(ast.get()))
{
switch (lit->value.getType())
{
Expand Down Expand Up @@ -909,7 +929,7 @@ bool TopN::toTiPBExecutor(tipb::Executor * tipb_executor, uint32_t collator_id,
tipb::TopN * topn = tipb_executor->mutable_topn();
for (const auto & child : order_columns)
{
ASTOrderByElement * elem = typeid_cast<ASTOrderByElement *>(child.get());
auto * elem = typeid_cast<ASTOrderByElement *>(child.get());
if (!elem)
throw Exception("Invalid order by element", ErrorCodes::LOGICAL_ERROR);
tipb::ByItem * by = topn->add_order_by();
Expand Down Expand Up @@ -954,7 +974,7 @@ bool Aggregation::toTiPBExecutor(tipb::Executor * tipb_executor, uint32_t collat
auto & input_schema = children[0]->output_schema;
for (const auto & expr : agg_exprs)
{
const ASTFunction * func = typeid_cast<const ASTFunction *>(expr.get());
const auto * func = typeid_cast<const ASTFunction *>(expr.get());
if (!func || !AggregateFunctionFactory::instance().isAggregateFunctionName(func->name))
throw Exception("Only agg function is allowed in select for a query with aggregation", ErrorCodes::LOGICAL_ERROR);

Expand Down Expand Up @@ -1024,7 +1044,7 @@ void Aggregation::columnPrune(std::unordered_set<String> & used_columns)
{
if (used_columns.find(func->getColumnName()) != used_columns.end())
{
const ASTFunction * agg_func = typeid_cast<const ASTFunction *>(func.get());
const auto * agg_func = typeid_cast<const ASTFunction *>(func.get());
if (agg_func != nullptr)
{
/// agg_func should not be nullptr, just double check
Expand Down Expand Up @@ -1075,7 +1095,7 @@ void Aggregation::toMPPSubPlan(size_t & executor_index, const DAGProperties & pr
/// re-construct agg_exprs and gby_exprs in final_agg
for (size_t i = 0; i < partial_agg->agg_exprs.size(); i++)
{
const ASTFunction * agg_func = typeid_cast<const ASTFunction *>(partial_agg->agg_exprs[i].get());
const auto * agg_func = typeid_cast<const ASTFunction *>(partial_agg->agg_exprs[i].get());
ASTPtr update_agg_expr = agg_func->clone();
auto * update_agg_func = typeid_cast<ASTFunction *>(update_agg_expr.get());
if (agg_func->name == "count")
Expand Down Expand Up @@ -1368,7 +1388,7 @@ ExecutorPtr compileTopN(ExecutorPtr input, size_t & executor_index, ASTPtr order
std::vector<ASTPtr> order_columns;
for (const auto & child : order_exprs->children)
{
ASTOrderByElement * elem = typeid_cast<ASTOrderByElement *>(child.get());
auto * elem = typeid_cast<ASTOrderByElement *>(child.get());
if (!elem)
throw Exception("Invalid order by element", ErrorCodes::LOGICAL_ERROR);
order_columns.push_back(child);
Expand Down Expand Up @@ -1399,7 +1419,7 @@ ExecutorPtr compileAggregation(ExecutorPtr input, size_t & executor_index, ASTPt
{
for (const auto & expr : agg_funcs->children)
{
const ASTFunction * func = typeid_cast<const ASTFunction *>(expr.get());
const auto * func = typeid_cast<const ASTFunction *>(expr.get());
if (!func || !AggregateFunctionFactory::instance().isAggregateFunctionName(func->name))
{
need_append_project = true;
Expand Down Expand Up @@ -1490,7 +1510,7 @@ ExecutorPtr compileProject(ExecutorPtr input, size_t & executor_index, ASTPtr se
output_schema.emplace_back(ft->first, ft->second);
continue;
}
const ASTFunction * func = typeid_cast<const ASTFunction *>(expr.get());
const auto * func = typeid_cast<const ASTFunction *>(expr.get());
if (func && AggregateFunctionFactory::instance().isAggregateFunctionName(func->name))
{
throw Exception("No such agg " + func->getColumnName(), ErrorCodes::NO_SUCH_COLUMN_IN_TABLE);
Expand Down
2 changes: 2 additions & 0 deletions dbms/src/Debug/astToExecutor.h
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,8 @@ ExecutorPtr compileExchangeSender(ExecutorPtr input, size_t & executor_index, ti

ExecutorPtr compileExchangeReceiver(size_t & executor_index, DAGSchema schema);

/// Encodes the literal `field` into `expr` using the type described by `ci`, and sets
/// `collator_id` as the collation of the expression's field type. A null `field` yields an
/// `ExprType::Null` expression with no value. Throws Exception for column types that have
/// no literal encoding.
void literalFieldToTiPBExpr(const ColumnInfo & ci, const Field & field, tipb::Expr * expr, Int32 collator_id);

//TODO: add compileWindow

} // namespace DB
24 changes: 1 addition & 23 deletions dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1279,15 +1279,7 @@ String DAGExpressionAnalyzer::getActions(const tipb::Expr & expr, const Expressi
}
else if (isScalarFunctionExpr(expr))
{
const String & func_name = getFunctionName(expr);
if (DAGExpressionAnalyzerHelper::function_builder_map.count(func_name) != 0)
{
ret = DAGExpressionAnalyzerHelper::function_builder_map[func_name](this, expr, actions);
}
else
{
ret = buildFunction(expr, actions);
}
ret = DAGExpressionAnalyzerHelper::buildFunction(this, expr, actions);
}
else
{
Expand Down Expand Up @@ -1341,18 +1333,4 @@ String DAGExpressionAnalyzer::buildTupleFunctionForGroupConcat(
return applyFunction(func_name, argument_names, actions, nullptr);
}

/// Builds the action for a scalar function expression: resolves the function name with
/// getFunctionName, obtains one argument name per child via getActions (in child order),
/// and hands them to applyFunction together with the collator from getCollatorFromExpr.
/// Returns whatever applyFunction returns.
String DAGExpressionAnalyzer::buildFunction(
    const tipb::Expr & expr,
    const ExpressionActionsPtr & actions)
{
    // Resolve the name first so it happens before any child is processed.
    const String & resolved_name = getFunctionName(expr);

    Names child_names;
    for (const auto & sub_expr : expr.children())
        child_names.push_back(getActions(sub_expr, actions));

    return applyFunction(resolved_name, child_names, actions, getCollatorFromExpr(expr));
}

} // namespace DB
4 changes: 0 additions & 4 deletions dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -275,10 +275,6 @@ class DAGExpressionAnalyzer : private boost::noncopyable
const ExpressionActionsPtr & actions,
const String & column_name);

String buildFunction(
const tipb::Expr & expr,
const ExpressionActionsPtr & actions);

String buildFilterColumn(
const ExpressionActionsPtr & actions,
const std::vector<const tipb::Expr *> & conditions);
Expand Down
Loading

0 comments on commit 973de13

Please sign in to comment.