Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support external udaf function #2825

Merged
merged 31 commits into from
Mar 22, 2023
Merged
Changes from 1 commit
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
c18bc2c
feat: support external udaf
dl239 Nov 24, 2022
011c229
refact: update type_ir_builder
dl239 Nov 25, 2022
392b404
refact: refact
dl239 Nov 29, 2022
360438b
Merge branch 'main' of github.com:dl239/OpenMLDB into feat/udaf
dl239 Nov 29, 2022
098e5a7
feat: add node
dl239 Nov 29, 2022
e0030c7
feat: build udaf call
dl239 Nov 30, 2022
66df39d
feat: add udaf registry
dl239 Nov 30, 2022
a98d7b3
feat: support drop
dl239 Dec 7, 2022
e823ad3
Merge branch 'main' of github.com:dl239/OpenMLDB into feat/udaf
dl239 Dec 7, 2022
aed213e
refact: rm unused code
dl239 Dec 7, 2022
539bf6f
refact: rm unused code
dl239 Dec 7, 2022
d480f1c
docs: add udaf docs
dl239 Dec 8, 2022
2d2b319
docs: update
dl239 Dec 8, 2022
2bce9b6
feat: support null args
dl239 Dec 9, 2022
be68dbf
Merge branch 'feat/udaf' of github.com:dl239/OpenMLDB into feat/udaf
dl239 Dec 9, 2022
4f50ac8
fix: fix test case
dl239 Dec 12, 2022
7f9e01c
feat: return result in arg if return_nullable is true
dl239 Dec 12, 2022
5ce768c
docs: update the docs
dl239 Dec 12, 2022
c7413cb
fix: fix test case
dl239 Dec 13, 2022
b89eeb1
Merge branch 'feat/udaf' of github.com:dl239/OpenMLDB into feat/udaf
dl239 Dec 13, 2022
c31b2d8
Merge branch 'main' of github.com:dl239/OpenMLDB into feat/udaf
dl239 Dec 13, 2022
5f99997
merge main
dl239 Dec 20, 2022
1483c9c
Merge branch 'feat/udaf' of github.com:dl239/OpenMLDB into feat/udaf
dl239 Dec 20, 2022
05a1b35
merge main
dl239 Mar 20, 2023
5e1a8c5
fix: fix compile
dl239 Mar 20, 2023
f583a41
fix: fix ExternalUdfUtil
dl239 Mar 20, 2023
497e527
fix: fix comment
dl239 Mar 21, 2023
cede5bd
fix: fix comment
dl239 Mar 21, 2023
d3719f8
fix: fix style
dl239 Mar 21, 2023
b2b8000
fix: fix compile
dl239 Mar 21, 2023
926d158
fix: fix compile
dl239 Mar 21, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
feat: support drop
dl239 committed Dec 7, 2022
commit a98d7b368d3fab5ebd386a84c8b0fe470cdf5acd
2 changes: 1 addition & 1 deletion hybridse/include/node/node_manager.h
Original file line number Diff line number Diff line change
@@ -360,7 +360,7 @@ class NodeManager {
DynamicUdafFnDefNode *MakeDynamicUdafFnDefNode(
const std::string &function_name, const std::vector<const TypeNode *> &arg_types,
ExternalFnDefNode *init_context_node, ExternalFnDefNode *init_node,
ExternalFnDefNode *update_node, ExternalFnDefNode *output_node);
FnDefNode *update_node, ExternalFnDefNode *output_node);

ExternalFnDefNode *MakeUnresolvedFnDefNode(
const std::string &function_name);
6 changes: 3 additions & 3 deletions hybridse/include/node/sql_node.h
Original file line number Diff line number Diff line change
@@ -2691,7 +2691,7 @@ class DynamicUdafFnDefNode : public FnDefNode {
public:
DynamicUdafFnDefNode(const std::string &name, const std::vector<const TypeNode *> &arg_types,
ExternalFnDefNode *init_context_node, ExternalFnDefNode *init_node,
ExternalFnDefNode *update_node, ExternalFnDefNode *output_node)
FnDefNode *update_node, ExternalFnDefNode *output_node)
: FnDefNode(kDynamicUdafFnDef),
name_(name),
arg_types_(arg_types),
@@ -2707,7 +2707,7 @@ class DynamicUdafFnDefNode : public FnDefNode {

ExternalFnDefNode* init_contex_func() const { return init_context_node_; }
ExternalFnDefNode* init_func() const { return init_node_; }
ExternalFnDefNode* update_func() const { return update_node_; }
FnDefNode* update_func() const { return update_node_; }
ExternalFnDefNode* output_func() const { return output_node_; }

base::Status Validate(const std::vector<const TypeNode *> &arg_types) const override;
@@ -2756,7 +2756,7 @@ class DynamicUdafFnDefNode : public FnDefNode {
std::vector<const TypeNode *> arg_types_;
ExternalFnDefNode *init_context_node_;
ExternalFnDefNode *init_node_;
ExternalFnDefNode *update_node_;
FnDefNode *update_node_;
ExternalFnDefNode *output_node_;
};

11 changes: 7 additions & 4 deletions hybridse/src/codegen/udf_ir_builder.cc
Original file line number Diff line number Diff line change
@@ -48,7 +48,7 @@ Status UdfIRBuilder::BuildCall(
// sanity checks
auto status = fn->Validate(arg_types);
if (!status.isOK()) {
LOG(WARNING) << "Validation error: " << status;
LOG(WARNING) << "Validation error: " << fn->GetName() << status;
}

switch (fn->GetType()) {
@@ -776,7 +776,6 @@ Status UdfIRBuilder::BuildDynamicUdafCall(const node::DynamicUdafFnDefNode* fn,
}
Status status;


std::vector<::llvm::Value*> list_ptrs;
for (size_t i = 0; i < input_num; ++i) {
list_ptrs.push_back(args[i].GetValue(ctx_));
@@ -820,16 +819,20 @@ Status UdfIRBuilder::BuildDynamicUdafCall(const node::DynamicUdafFnDefNode* fn,
ListIRBuilder iter_next_builder(body_begin_block, nullptr);
UdfIRBuilder sub_udf_builder(ctx_, frame_arg_, frame_);

std::vector<NativeValue> update_args;
std::vector<NativeValue> update_args = {udfcontext_output};
for (size_t i = 0; i < input_num; ++i) {
NativeValue next_val;
CHECK_STATUS(iter_next_builder.BuildIteratorNext(
iterators[i], elem_types[i], elem_nullable[i], &next_val));
update_args.push_back(next_val);
}
std::vector<const node::TypeNode*> update_arg_types = { fn->update_func()->GetArgType(0) };
for (size_t i = 0; i < input_num; ++i) {
update_arg_types.push_back(elem_types[i]);
}
NativeValue update_value;
CHECK_TRUE(fn->update_func() != nullptr, kCodegenError);
CHECK_STATUS(sub_udf_builder.BuildExternCall(fn->update_func(), update_args, &update_value));
CHECK_STATUS(sub_udf_builder.BuildCall(fn->update_func(), update_arg_types, update_args, &update_value));
return Status::OK();
}));

2 changes: 1 addition & 1 deletion hybridse/src/node/node_manager.cc
Original file line number Diff line number Diff line change
@@ -1095,7 +1095,7 @@ SqlNode *NodeManager::MakeInputParameterNode(bool is_constant, const std::string
DynamicUdafFnDefNode *NodeManager::MakeDynamicUdafFnDefNode(const std::string &function_name,
const std::vector<const TypeNode *> &arg_types,
ExternalFnDefNode *init_context_node, ExternalFnDefNode *init_node,
ExternalFnDefNode *update_node, ExternalFnDefNode *output_node) {
FnDefNode *update_node, ExternalFnDefNode *output_node) {
return RegisterNode(new DynamicUdafFnDefNode(function_name, arg_types, init_context_node,
init_node, update_node, output_node));
}
40 changes: 0 additions & 40 deletions hybridse/src/node/sql_node.cc
Original file line number Diff line number Diff line change
@@ -2610,46 +2610,6 @@ Status DynamicUdafFnDefNode::Validate(const std::vector<const TypeNode *> &arg_t
}
// init check
CHECK_TRUE(GetStateType() != nullptr, kTypeError, "State type not inferred");
if (init_func() == nullptr) {
CHECK_TRUE(arg_types_.size() == 1, kTypeError, "Only support single input if init not set");
} else {
CHECK_TRUE(init_func()->GetReturnType() != nullptr, kTypeError)
CHECK_TRUE(init_func()->GetReturnType()->Equals(GetStateType()), kTypeError, "Init type expect to be ",
GetStateType()->GetName(), ", but get ", init_func()->GetReturnType()->GetName());
}
// update check
CHECK_TRUE(update_func()->GetArgSize() == 1 + arg_types_.size(), kTypeError, "Update should take ",
1 + arg_types_.size(), ", get ", update_func()->GetArgSize());
for (size_t i = 0; i < arg_types_.size() + 1; ++i) {
auto arg_type = update_func()->GetArgType(i);
CHECK_TRUE(arg_type != nullptr, kTypeError, i, "th update argument type is not inferred");
if (i == 0) {
CHECK_TRUE(arg_type->Equals(GetStateType()), kTypeError, "Update's first argument type should be ",
GetStateType()->GetName(), ", but get ", arg_type->GetName());
} else {
CHECK_TRUE(arg_type->Equals(GetElementType(i - 1)), kTypeError, "Update's ", i,
"th argument type should be ", GetElementType(i - 1), ", but get ", arg_type->GetName());
}
}
// output check
if (output_func() != nullptr) {
CHECK_TRUE(output_func()->GetArgSize() == 1, kTypeError, "Output should take 1 arguments, but get ",
output_func()->GetArgSize());
CHECK_TRUE(output_func()->GetArgType(0) != nullptr, kTypeError);
CHECK_TRUE(output_func()->GetArgType(0)->Equals(GetStateType()), kTypeError,
"Output's 0th argument type should be ", GetStateType(), ", but get ",
output_func()->GetArgType(0)->GetName());
CHECK_TRUE(output_func()->GetReturnType() != nullptr, kTypeError);
}
// actual args check
CHECK_TRUE(arg_types.size() == arg_types_.size(), kTypeError, GetName(), " expect ", arg_types_.size(),
" inputs, but get ", arg_types.size());
for (size_t i = 0; i < arg_types.size(); ++i) {
if (arg_types[i] != nullptr) {
CHECK_TRUE(arg_types_[i]->Equals(arg_types[i]), kTypeError, GetName(), "'s ", i, "th argument expect ",
arg_types_[i]->GetName(), ", but get ", arg_types[i]->GetName());
}
}
return Status::OK();
}

56 changes: 34 additions & 22 deletions hybridse/src/passes/lambdafy_projects.cc
Original file line number Diff line number Diff line change
@@ -308,18 +308,8 @@ Status LambdafyProjects::VisitAggExpr(node::CallExprNode* call,
CHECK_STATUS(ctx_->library()->ResolveFunction(
fn->function_name(), agg_original_args, nm, &fn_def),
"Resolve original udaf for ", fn->function_name(), " failed");
auto origin_udaf = dynamic_cast<node::UdafDefNode*>(fn_def);
CHECK_TRUE(origin_udaf != nullptr, kCodegenError, fn->function_name(),
" is not an udaf");

// refer to original udaf's functionalities
auto ori_update_fn = origin_udaf->update_func();
auto ori_merge_fn = origin_udaf->merge_func();
auto ori_output_fn = origin_udaf->output_func();
auto ori_init = origin_udaf->init_expr();
CHECK_TRUE(
ori_init != nullptr, kCodegenError,
"Do not support use first element as init state for lambdafy udaf");
CHECK_TRUE(fn_def != nullptr && (fn_def->GetType() == node::kUdafDef || fn_def->GetType() == node::kDynamicUdafFnDef),
kCodegenError, fn->function_name(), " is not an udaf")

// build new udaf update function
std::vector<node::ExprNode*> actual_update_args;
@@ -362,12 +352,6 @@ Status LambdafyProjects::VisitAggExpr(node::CallExprNode* call,
actual_update_args.push_back(transformed_child[i]);
}
}

// wrap actual update call into proxy update function
auto update_body =
nm->MakeFuncNode(ori_update_fn, actual_update_args, nullptr);
auto update_func = nm->MakeLambdaNode(proxy_update_args, update_body);

std::string new_udaf_name = "window_agg_$";
new_udaf_name.append(fn->function_name());
new_udaf_name.append("<");
@@ -379,10 +363,38 @@ Status LambdafyProjects::VisitAggExpr(node::CallExprNode* call,
}
new_udaf_name.append(">");

auto new_udaf =
nm->MakeUdafDefNode(new_udaf_name, proxy_udaf_arg_types, ori_init,
update_func, ori_merge_fn, ori_output_fn);
*out = nm->MakeFuncNode(new_udaf, proxy_udaf_args, nullptr);
if (fn_def->GetType() == node::kUdafDef) {
auto origin_udaf = dynamic_cast<node::UdafDefNode*>(fn_def);
CHECK_TRUE(origin_udaf != nullptr, kCodegenError, fn->function_name(), " is not an udaf")
auto ori_update_fn = origin_udaf->update_func();
auto ori_merge_fn = origin_udaf->merge_func();
auto ori_output_fn = origin_udaf->output_func();
auto ori_init = origin_udaf->init_expr();
CHECK_TRUE(ori_init != nullptr, kCodegenError,
"Do not support use first element as init state for lambdafy udaf");

// wrap actual update call into proxy update function
auto update_body = nm->MakeFuncNode(ori_update_fn, actual_update_args, nullptr);
auto update_func = nm->MakeLambdaNode(proxy_update_args, update_body);
auto new_udaf = nm->MakeUdafDefNode(new_udaf_name, proxy_udaf_arg_types,
ori_init, update_func, ori_merge_fn, ori_output_fn);
*out = nm->MakeFuncNode(new_udaf, proxy_udaf_args, nullptr);
} else {
auto origin_udaf = dynamic_cast<node::DynamicUdafFnDefNode*>(fn_def);
CHECK_TRUE(origin_udaf != nullptr, kCodegenError, fn->function_name(), " is not an udaf")
auto ori_init_contex_fn = origin_udaf->init_contex_func();
auto ori_init_fn = origin_udaf->init_func();
auto ori_update_fn = origin_udaf->update_func();
auto ori_output_fn = origin_udaf->output_func();

// wrap actual update call into proxy update function
proxy_update_args[0]->SetOutputType(ori_update_fn->GetArgType(0));
auto update_body = nm->MakeFuncNode(ori_update_fn, actual_update_args, nullptr);
auto update_func = nm->MakeLambdaNode(proxy_update_args, update_body);
auto new_udaf = nm->MakeDynamicUdafFnDefNode(new_udaf_name, proxy_udaf_arg_types,
ori_init_contex_fn, ori_init_fn, update_func, ori_output_fn);
*out = nm->MakeFuncNode(new_udaf, proxy_udaf_args, nullptr);
}
return Status::OK();
}

15 changes: 9 additions & 6 deletions hybridse/src/udf/dynamic_lib_manager.cc
Original file line number Diff line number Diff line change
@@ -51,22 +51,25 @@ base::Status DynamicLibManager::ExtractFunction(const std::string& name, bool is
handle_map_.emplace(file, so_handle);
}
if (is_aggregate) {
auto init_fun = dlsym(so_handle->handle, std::string(name + "_init").c_str());
std::string init_fun_name = name + "_init";
auto init_fun = dlsym(so_handle->handle, init_fun_name.c_str());
if (init_fun == nullptr) {
RemoveHandler(file);
return {common::kExternalUDFError, "can not find the init function: " + name};
return {common::kExternalUDFError, "can not find the init function: " + init_fun_name};
}
funs->emplace_back(init_fun);
auto update_fun = dlsym(so_handle->handle, std::string(name + "_update").c_str());
std::string update_fun_name = name + "_update";
auto update_fun = dlsym(so_handle->handle, update_fun_name.c_str());
if (update_fun == nullptr) {
RemoveHandler(file);
return {common::kExternalUDFError, "can not find the update function: " + name};
return {common::kExternalUDFError, "can not find the update function: " + update_fun_name};
}
funs->emplace_back(update_fun);
auto output_fun = dlsym(so_handle->handle, std::string(name + "_output").c_str());
std::string output_fun_name = name + "_output";
auto output_fun = dlsym(so_handle->handle, output_fun_name.c_str());
if (output_fun == nullptr) {
RemoveHandler(file);
return {common::kExternalUDFError, "can not find the output function: " + name};
return {common::kExternalUDFError, "can not find the output function: " + output_fun_name};
}
funs->emplace_back(output_fun);
} else {
41 changes: 35 additions & 6 deletions hybridse/src/udf/udf_library.cc
Original file line number Diff line number Diff line change
@@ -222,10 +222,20 @@ Status UdfLibrary::RegisterDynamicUdf(const std::string& name, node::DataType re
Status status;
if (is_aggregate) {
CHECK_TRUE(funs.size() == 3, kCodegenError, "cannot find function in so")
DynamicUdafRegistryHelper helper(canon_name, this, return_type, arg_types,
/*DynamicUdafRegistryHelper helper(canon_name, this, return_type, arg_types,
reinterpret_cast<void*>(static_cast<void (*)(UDFContext* context)>(udf::v1::init_udfcontext)),
funs[0], funs[1], funs[2]);
status = helper.Register();
status = helper.Register();*/
void* init_context_ptr =
reinterpret_cast<void*>(static_cast<void (*)(UDFContext* context)>(udf::v1::init_udfcontext));
DynamicUdafRegistryHelperImpl helper(canon_name, this, return_type, arg_types);
std::string lib_name = canon_name;
for (const auto type : arg_types) {
lib_name.append(".").append(node::DataTypeName(type));
}
helper.init(lib_name + ".init", init_context_ptr, funs[0])
.update(lib_name + ".update", funs[1])
.output(lib_name + ".output", funs[2]);
} else {
CHECK_TRUE(!funs.empty() && funs[0] != nullptr, kCodegenError, name + " is nullptr")
void* fn = funs[0];
@@ -247,12 +257,31 @@ Status UdfLibrary::RemoveDynamicUdf(const std::string& name, const std::vector<n
for (const auto type : arg_types) {
lib_name.append(".").append(node::DataTypeName(type));
}
std::lock_guard<std::mutex> lock(mu_);
if (table_.erase(canonical_name) <= 0) {
if (!HasFunction(canonical_name)) {
return Status(kCodegenError, "can not find the function " + canonical_name);
}
if (external_symbols_.erase(lib_name) <= 0) {
return Status(kCodegenError, "can not find the function " + lib_name);
if (IsUdaf(canonical_name)) {
std::lock_guard<std::mutex> lock(mu_);
if (table_.erase(canonical_name) <= 0) {
return Status(kCodegenError, "can not find the function " + canonical_name);
}
if (external_symbols_.erase(lib_name + ".init") <= 0) {
return Status(kCodegenError, "can not find the init function " + lib_name);
}
if (external_symbols_.erase(lib_name + ".update") <= 0) {
return Status(kCodegenError, "can not find the update function " + lib_name);
}
if (external_symbols_.erase(lib_name + ".output") <= 0) {
return Status(kCodegenError, "can not find the output function " + lib_name);
}
} else {
std::lock_guard<std::mutex> lock(mu_);
if (table_.erase(canonical_name) <= 0) {
return Status(kCodegenError, "can not find the function " + canonical_name);
}
if (external_symbols_.erase(lib_name) <= 0) {
return Status(kCodegenError, "can not find the function " + lib_name);
}
}
return lib_manager_.RemoveHandler(file);
}
154 changes: 141 additions & 13 deletions hybridse/src/udf/udf_registry.cc
Original file line number Diff line number Diff line change
@@ -267,6 +267,30 @@ Status DynamicUdafRegistry::ResolveFunction(UdfResolveContext* ctx,
CHECK_TRUE(extern_def_->GetReturnType() != nullptr, kCodegenError,
"No return type specified for ", extern_def_->GetName());
DLOG(INFO) << "Resolve udaf \"" << name() << "\" -> " << extern_def_->GetFlatString();
/*auto nm = ctx->node_manager();
auto state_arg = nm->MakeExprIdNode("state");
state_arg->SetOutputType(extern_def_->GetStateType());
state_arg->SetNullable(0);
std::vector<node::ExprNode*> update_args;
update_args.push_back(state_arg);
std::vector<const node::TypeNode*> list_types;
for (size_t i = 0; i < ctx->arg_size(); ++i) {
auto elem_arg = nm->MakeExprIdNode("elem_" + std::to_string(i));
auto list_type = ctx->arg_type(i);
CHECK_TRUE(list_type != nullptr && list_type->base() == node::kList,
kCodegenError);
elem_arg->SetOutputType(list_type->GetGenericType(0));
elem_arg->SetNullable(list_type->IsGenericNullable(0));
update_args.push_back(elem_arg);
list_types.push_back(list_type);
}
UdfResolveContext update_ctx(update_args, nm, ctx->library());
CHECK_TRUE(extern_def_->update_func() != nullptr, kCodegenError);
CHECK_STATUS(
extern_def_->update_func()->ResolveFunction(&update_ctx, &update_func),
"Resolve update function of ", name(), " failed");
*result = nm->MakeDynamicUdafFnDefNode(name(), list_types, extern_def_->init_contex_func(),
extern_def_->init_func(), update_func, extern_def_->output_func());*/
*result = extern_def_;
return Status::OK();
}
@@ -388,9 +412,9 @@ DynamicUdafRegistryHelper::DynamicUdafRegistryHelper(const std::string& basename
return_type_ = nm->MakeTypeNode(return_type);
for (const auto type : arg_types) {
auto type_node = nm->MakeTypeNode(type);
arg_types_.emplace_back(type_node);
arg_types_.push_back(type_node);
fn_name_.append(".").append(type_node->GetName());
arg_nullable_.emplace_back(0);
arg_nullable_.push_back(0);
}
switch (return_type) {
case node::kVarchar:
@@ -403,6 +427,15 @@ DynamicUdafRegistryHelper::DynamicUdafRegistryHelper(const std::string& basename
}
}

std::string DynamicUdafRegistryHelper::GetFunName(const std::string& base_name,
const std::vector<const node::TypeNode*>& arg_types) {
std::string fn_name = base_name;
for (auto type_node : arg_types) {
fn_name.append(".").append(type_node->GetName());
}
return fn_name;
}

Status DynamicUdafRegistryHelper::Register() {
if (udfcontext_fun_ptr_ == nullptr || init_fn_ptr_ == nullptr
|| update_fn_ptr_ == nullptr || output_fn_ptr_ == nullptr) {
@@ -419,34 +452,129 @@ Status DynamicUdafRegistryHelper::Register() {
init_context_fn_name, udfcontext_fun_ptr_, type_node, false, {}, {}, -1, true);

auto void_type_node = node_manager()->MakeTypeNode(node::DataType::kVoid);
auto ptr_type_node = node_manager()->MakeTypeNode(node::DataType::kInt8Ptr);

std::string init_fn_name = fn_name_ + "_init.udfcontext";
std::string init_fn_name = GetFunName(name() + "_init", {type_node});
auto init_node = node_manager()->MakeExternalFnDefNode(init_fn_name, init_fn_ptr_,
void_type_node, false, {ptr_type_node}, {0}, -1, false);
std::string update_fn_name = fn_name_ + "_update.udfcontext";
std::vector<const node::TypeNode *> new_arg_types = {ptr_type_node};
void_type_node, false, {type_node}, {0}, -1, false);
std::vector<const node::TypeNode *> new_arg_types = {type_node};
new_arg_types.insert(new_arg_types.end(), arg_types_.begin(), arg_types_.end());
std::vector<int> new_arg_nullable = {0};
new_arg_nullable.insert(new_arg_nullable.end(), arg_nullable_.begin(), arg_nullable_.end());
std::string update_fn_name = GetFunName(name() + "_update", new_arg_types);
auto update_node = node_manager()->MakeExternalFnDefNode(update_fn_name, update_fn_ptr_,
void_type_node, false, new_arg_types, new_arg_nullable, -1, false);

std::string output_fn_name = fn_name_ + "_output.udfcontext";
std::string output_fn_name = GetFunName(name() + "_output", {type_node});
auto output_node = node_manager()->MakeExternalFnDefNode(output_fn_name, output_fn_ptr_,
return_type_, return_by_arg_, {ptr_type_node}, {0}, -1, false);

return_type_, return_by_arg_, {type_node}, {0}, -1, false);

std::vector<const node::TypeNode*> input_list_types;
for (auto elem_ty : arg_types_) {
input_list_types.push_back(node_manager()->MakeTypeNode(node::kList, elem_ty));
}
auto def = node_manager()->MakeDynamicUdafFnDefNode(
fn_name_, arg_types_, init_context_node, init_node, update_node, output_node);
fn_name_, input_list_types, init_context_node, init_node, update_node, output_node);
auto registry = std::make_shared<DynamicUdafRegistry>(name(), def);
library()->AddExternalFunction(init_fn_name, init_fn_ptr_);
library()->AddExternalFunction(update_fn_name, update_fn_ptr_);
library()->AddExternalFunction(output_fn_name, output_fn_ptr_);
this->InsertRegistry(arg_types_, false, registry);
LOG(INFO) << "register function success. name: " << fn_name_ << " return type:" << return_type_->GetName();
this->InsertRegistry(input_list_types, false, registry);
library()->SetIsUdaf(name(), input_list_types.size());
LOG(INFO) << "register function success. name: " << name() << " return type:" << return_type_->GetName();
return Status::OK();
}

DynamicUdafRegistryHelperImpl::DynamicUdafRegistryHelperImpl(const std::string& name, UdfLibrary* library,
node::DataType return_type, const std::vector<node::DataType>& arg_types) : UdfRegistryHelper(name, library) {
auto nm = node_manager();
state_ty_ = nm->MakeOpaqueType(sizeof(UDFContext));
state_nullable_ = false;
update_tys_.push_back(state_ty_);
update_nullable_.push_back(state_nullable_);
for (const auto type : arg_types) {
auto type_node = nm->MakeTypeNode(type);
elem_tys_.push_back(type_node);
elem_nullable_.push_back(0);
update_tys_.push_back(type_node);
update_nullable_.push_back(0);
}
switch (return_type) {
case node::kVarchar:
case node::kDate:
case node::kTimestamp:
return_by_arg_ = true;
break;
default:
return_by_arg_ = false;
}
output_ty_ = nm->MakeTypeNode(return_type);
output_nullable_ = false;
}

DynamicUdafRegistryHelperImpl& DynamicUdafRegistryHelperImpl::init(const std::string& fname,
void* init_context_ptr, void* fn_ptr) {
auto fn = library()->node_manager()->MakeExternalFnDefNode(fname, fn_ptr,
state_ty_, false, {state_ty_}, {0}, -1, false);
library()->AddExternalFunction(fname, fn_ptr);
auto type_node = state_ty_;
udaf_gen_.init_gen =
std::make_shared<DynamicExprUdfGen>([init_context_ptr, type_node, fn](UdfResolveContext* ctx) {
std::string init_context_fn_name = "init_udfcontext.opaque";
auto init_contex_fn = ctx->node_manager()->MakeExternalFnDefNode(init_context_fn_name,
init_context_ptr, type_node, false, {}, {}, -1, true);
auto init_contex_call = ctx->node_manager()->MakeFuncNode(init_contex_fn, {}, nullptr);
return ctx->node_manager()->MakeFuncNode(fn, {init_contex_call}, nullptr);
});
return *this;
}

DynamicUdafRegistryHelperImpl& DynamicUdafRegistryHelperImpl::update(const std::string& fname, void* fn_ptr) {
auto fn = library()->node_manager()->MakeExternalFnDefNode(fname, fn_ptr,
state_ty_, state_nullable_, update_tys_, update_nullable_, -1, false);
auto registry = std::make_shared<ExternalFuncRegistry>(fname, fn);
udaf_gen_.update_gen = registry;
library()->AddExternalFunction(fname, fn_ptr);
return *this;
}

DynamicUdafRegistryHelperImpl& DynamicUdafRegistryHelperImpl::output(const std::string& fname, void* fn_ptr) {
auto fn = library()->node_manager()->MakeExternalFnDefNode(fname, fn_ptr,
output_ty_, output_nullable_, {state_ty_}, {state_nullable_}, -1, return_by_arg_);
auto registry = std::make_shared<ExternalFuncRegistry>(fname, fn);
udaf_gen_.output_gen = registry;
library()->AddExternalFunction(fname, fn_ptr);
return *this;
}

void DynamicUdafRegistryHelperImpl::finalize() {
if (elem_tys_.empty()) {
LOG(WARNING) << "UDAF must take at least one input";
return;
}
if (udaf_gen_.init_gen == nullptr) {
if (!(elem_tys_.size() == 1 && elem_tys_[0]->Equals(state_ty_))) {
LOG(WARNING) << "no init expr provided but input type does not equal to state type";
return;
}
}
if (udaf_gen_.update_gen == nullptr) {
LOG(WARNING) << "update function not specified for " << name();
return;
}
if (udaf_gen_.output_gen == nullptr) {
LOG(WARNING) << "output function not specified for " << name();
return;
}
udaf_gen_.state_type = state_ty_;
udaf_gen_.state_nullable = state_nullable_;
std::vector<const node::TypeNode*> input_list_types;
for (auto elem_ty : elem_tys_) {
input_list_types.push_back(library()->node_manager()->MakeTypeNode(node::kList, elem_ty));
}
auto registry = std::make_shared<UdafRegistry>(name(), udaf_gen_);
this->InsertRegistry(input_list_types, false, registry);
library()->SetIsUdaf(name(), elem_tys_.size());
}

} // namespace udf
} // namespace hybridse
38 changes: 38 additions & 0 deletions hybridse/src/udf/udf_registry.h
Original file line number Diff line number Diff line change
@@ -184,6 +184,15 @@ struct ExprUdfGen : public ExprUdfGenBase {
const FType gen_func;
};

struct DynamicExprUdfGen : public ExprUdfGenBase {
using FType = std::function<ExprNode*(UdfResolveContext*)>;
explicit DynamicExprUdfGen(const FType& f) : gen_func(f) {}
ExprNode* gen(UdfResolveContext* ctx, const std::vector<ExprNode*>& args) override {
return gen_func(ctx);
}
const FType gen_func;
};

template <typename... Args>
struct VariadicExprUdfGen : public ExprUdfGenBase {
using FType = std::function<ExprNode*(
@@ -1271,6 +1280,9 @@ class DynamicUdafRegistryHelper : public UdfRegistryHelper {
void* udfcontext_fun, void* init_fn_ptr, void* update_fn_ptr, void* output_fn_ptr);
Status Register();

private:
std::string GetFunName(const std::string& base_name, const std::vector<const node::TypeNode*>& arg_types);

private:
std::string fn_name_;
void* udfcontext_fun_ptr_;
@@ -1743,6 +1755,32 @@ class UdafRegistryHelperImpl : UdfRegistryHelper {
std::vector<std::string> update_tags_;
};

class DynamicUdafRegistryHelperImpl : public UdfRegistryHelper {
public:
DynamicUdafRegistryHelperImpl(const std::string& basename, UdfLibrary* library,
node::DataType return_type, const std::vector<node::DataType>& arg_types);
~DynamicUdafRegistryHelperImpl() { finalize(); }

DynamicUdafRegistryHelperImpl& init(const std::string& fname, void* init_context_ptr, void* fn_ptr);
DynamicUdafRegistryHelperImpl& update(const std::string& fname, void* fn_ptr);
DynamicUdafRegistryHelperImpl& output(const std::string& fname, void* fn_ptr);

void finalize();

private:
std::vector<const node::TypeNode*> elem_tys_;
std::vector<int> elem_nullable_;
node::TypeNode* state_ty_;
bool state_nullable_;
node::TypeNode* output_ty_;
bool output_nullable_;
bool return_by_arg_;

UdafDefGen udaf_gen_;
std::vector<const node::TypeNode*> update_tys_;
std::vector<int> update_nullable_;
};

template <typename OUT, typename ST, typename... IN>
UdafRegistryHelperImpl<OUT, ST, IN...> UdafRegistryHelper::templates() {
auto helper_impl =
58 changes: 58 additions & 0 deletions src/cmd/single_tablet_test.cc
Original file line number Diff line number Diff line change
@@ -132,6 +132,64 @@ TEST_P(DBSDKTest, CreateFunction) {
absl::StrCat("drop database ", db_name, ";"),
});
}

TEST_P(DBSDKTest, CreateUdafFunction) {
auto cli = GetParam();
cs = cli->cs;
sr = cli->sr;
std::unique_ptr<::openmldb::sdk::SQLClusterRouter> sr_2;
if (cs->IsClusterMode()) {
::openmldb::sdk::ClusterOptions copt;
copt.zk_cluster = mc.GetZkCluster();
copt.zk_path = mc.GetZkPath();
auto cur_cs = new ::openmldb::sdk::ClusterSDK(copt);
cur_cs->Init();
sr_2 = std::make_unique<::openmldb::sdk::SQLClusterRouter>(cur_cs);
sr_2->Init();
ProcessSQLs(sr_2.get(), {"set @@execute_mode = 'online'"});
}
hybridse::sdk::Status status;
std::string so_path = openmldb::test::GetParentDir(openmldb::test::GetExeDir()) + "/libtest_udf.so";
std::string agg_fun_str = absl::StrCat("CREATE AGGREGATE FUNCTION special_sum(x BIGINT) RETURNS BIGINT "
"OPTIONS (FILE='", so_path, "');");
std::string db_name = "test" + GenRand();
std::string tb_name = "t1";
ProcessSQLs(sr,
{
"set @@execute_mode = 'online'",
absl::StrCat("create database ", db_name, ";"),
absl::StrCat("use ", db_name, ";"),
absl::StrCat("create table ", tb_name, " (c1 string, c2 bigint, c3 double);"),
absl::StrCat("insert into ", tb_name, " values ('aab', 11, 1.2);"),
absl::StrCat("insert into ", tb_name, " values ('aab', 12, 1.2);"),
agg_fun_str
});
auto result = sr->ExecuteSQL("show functions", &status);
ExpectResultSetStrEq({{"Name", "Return_type", "Arg_type", "Is_aggregate", "File"},
{"special_sum", "BigInt", "BigInt", "true", so_path}},
result.get());
result = sr->ExecuteSQL("select special_sum(c2) as sumc2 from t1;", &status);
ASSERT_TRUE(status.IsOK());
ASSERT_EQ(1, result->Size());
result->Next();
int64_t value = 0;
result->GetInt64(0, &value);
ASSERT_EQ(value, 38);
if (cs->IsClusterMode()) {
ProcessSQLs(sr_2.get(), {"set @@execute_mode = 'online'", absl::StrCat("use ", db_name, ";")});
// check function in another sdk
result = sr_2->ExecuteSQL("select special_sum(c2) as sumc2 from t1;", &status);
ASSERT_TRUE(status.IsOK()) << status.msg;
ASSERT_EQ(1, result->Size());
result->Next();
int64_t value = 0;
result->GetInt64(0, &value);
ASSERT_EQ(value, 38);
}
ProcessSQLs(sr, {"DROP FUNCTION special_sum;"});
result = sr->ExecuteSQL("select special_sum(c2) as sumc2 from t1;", &status);
ASSERT_FALSE(status.IsOK());
}
#endif

INSTANTIATE_TEST_SUITE_P(DBSDK, DBSDKTest, testing::Values(&standalone_cli, &cluster_cli));
22 changes: 22 additions & 0 deletions src/examples/test_udf.cc
Original file line number Diff line number Diff line change
@@ -44,3 +44,25 @@ void int2str(UDFContext* ctx, int32_t input, StringRef* output) {
output->size_ = tmp.length();
output->data_ = buffer;
}


// udaf example
extern "C"
UDFContext* special_sum_init(UDFContext* ctx) {
ctx->ptr = ctx->pool->Alloc(sizeof(int64_t));
*(reinterpret_cast<int64_t*>(ctx->ptr)) = 10;
return ctx;
}

extern "C"
UDFContext* special_sum_update(UDFContext* ctx, int64_t input) {
int64_t cur = *(reinterpret_cast<int64_t*>(ctx->ptr));
cur += input;
*(reinterpret_cast<int*>(ctx->ptr)) = cur;
return ctx;
}

extern "C"
int64_t special_sum_output(UDFContext* ctx) {
return *(reinterpret_cast<int64_t*>(ctx->ptr)) + 5;
}
3 changes: 0 additions & 3 deletions src/sdk/sql_cluster_router.cc
Original file line number Diff line number Diff line change
@@ -3168,9 +3168,6 @@ hybridse::sdk::Status SQLClusterRouter::HandleCreateFunction(const hybridse::nod
}
fun->add_arg_type(data_type);
}
if (node->IsAggregate()) {
return {StatusCode::kCmdError, "unsupport udaf function"};
}
fun->set_is_aggregate(node->IsAggregate());
auto option = node->Options();
if (!option || option->find("FILE") == option->end()) {