Skip to content

Commit

Permalink
Merge branch 'master' into extension_bootstrap
Browse files Browse the repository at this point in the history
  • Loading branch information
lcwangchao committed Oct 24, 2022
2 parents cef6bca + ad0f7d2 commit ad7f253
Show file tree
Hide file tree
Showing 9 changed files with 596 additions and 0 deletions.
3 changes: 3 additions & 0 deletions expression/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ go_library(
"explain.go",
"expr_to_pb.go",
"expression.go",
"extension.go",
"function_traits.go",
"helper.go",
"partition_pruner.go",
Expand All @@ -66,6 +67,7 @@ go_library(
deps = [
"//config",
"//errno",
"//extension",
"//kv",
"//parser",
"//parser/ast",
Expand Down Expand Up @@ -97,6 +99,7 @@ go_library(
"//util/parser",
"//util/plancodec",
"//util/printer",
"//util/sem",
"//util/set",
"//util/size",
"//util/sqlexec",
Expand Down
7 changes: 7 additions & 0 deletions expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -944,6 +944,13 @@ func GetBuiltinList() []string {
}
res = append(res, funcName)
}

extensionFuncs.Range(func(key, _ any) bool {
funcName := key.(string)
res = append(res, funcName)
return true
})

slices.Sort(res)
return res
}
Expand Down
185 changes: 185 additions & 0 deletions expression/extension.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
// Copyright 2022 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package expression

import (
"strings"
"sync"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/extension"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/privilege"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/sem"
)

var extensionFuncs sync.Map

func registerExtensionFunc(def *extension.FunctionDef) error {
if def == nil {
return errors.New("extension function def is nil")
}

if def.Name == "" {
return errors.New("extension function name should not be empty")
}

lowerName := strings.ToLower(def.Name)
if _, ok := funcs[lowerName]; ok {
return errors.Errorf("extension function name '%s' conflict with builtin", def.Name)
}

class, err := newExtensionFuncClass(def)
if err != nil {
return err
}

_, exist := extensionFuncs.LoadOrStore(lowerName, class)
if exist {
return errors.Errorf("duplicated extension function name '%s'", def.Name)
}

return nil
}

func removeExtensionFunc(name string) {
extensionFuncs.Delete(name)
}

type extensionFuncClass struct {
baseFunctionClass
funcDef extension.FunctionDef
flen int
}

func newExtensionFuncClass(def *extension.FunctionDef) (*extensionFuncClass, error) {
var flen int
switch def.EvalTp {
case types.ETString:
flen = mysql.MaxFieldVarCharLength
if def.EvalStringFunc == nil {
return nil, errors.New("eval function is nil")
}
case types.ETInt:
flen = mysql.MaxIntWidth
if def.EvalIntFunc == nil {
return nil, errors.New("eval function is nil")
}
default:
return nil, errors.Errorf("unsupported extension function ret type: '%v'", def.EvalTp)
}

return &extensionFuncClass{
baseFunctionClass: baseFunctionClass{def.Name, len(def.ArgTps), len(def.ArgTps)},
flen: flen,
funcDef: *def,
}, nil
}

func (c *extensionFuncClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
if err := c.checkPrivileges(ctx); err != nil {
return nil, err
}

if err := c.verifyArgs(args); err != nil {
return nil, err
}
bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, c.funcDef.EvalTp, c.funcDef.ArgTps...)
if err != nil {
return nil, err
}
bf.tp.SetFlen(c.flen)
sig := &extensionFuncSig{bf, c.funcDef}
return sig, nil
}

func (c *extensionFuncClass) checkPrivileges(ctx sessionctx.Context) error {
privs := c.funcDef.RequireDynamicPrivileges
if semPrivs := c.funcDef.SemRequireDynamicPrivileges; len(semPrivs) > 0 && sem.IsEnabled() {
privs = semPrivs
}

if len(privs) == 0 {
return nil
}

manager := privilege.GetPrivilegeManager(ctx)
activeRoles := ctx.GetSessionVars().ActiveRoles

for _, priv := range privs {
if !manager.RequestDynamicVerification(activeRoles, priv, false) {
msg := priv
if !sem.IsEnabled() {
msg = "SUPER or " + msg
}
return errSpecificAccessDenied.GenWithStackByArgs(msg)
}
}

return nil
}

var _ extension.FunctionContext = &extensionFuncSig{}

type extensionFuncSig struct {
baseBuiltinFunc
extension.FunctionDef
}

func (b *extensionFuncSig) Clone() builtinFunc {
newSig := &extensionFuncSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.FunctionDef = b.FunctionDef
return newSig
}

func (b *extensionFuncSig) evalString(row chunk.Row) (string, bool, error) {
if b.EvalTp == types.ETString {
return b.EvalStringFunc(b, row)
}
return b.baseBuiltinFunc.evalString(row)
}

func (b *extensionFuncSig) evalInt(row chunk.Row) (int64, bool, error) {
if b.EvalTp == types.ETInt {
return b.EvalIntFunc(b, row)
}
return b.baseBuiltinFunc.evalInt(row)
}

func (b *extensionFuncSig) EvalArgs(row chunk.Row) ([]types.Datum, error) {
if len(b.args) == 0 {
return nil, nil
}

result := make([]types.Datum, 0, len(b.args))
for _, arg := range b.args {
val, err := arg.Eval(row)
if err != nil {
return nil, err
}
result = append(result, val)
}

return result, nil
}

func init() {
extension.RegisterExtensionFunc = registerExtensionFunc
extension.RemoveExtensionFunc = removeExtensionFunc
}
7 changes: 7 additions & 0 deletions expression/scalar_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,13 @@ func newFunctionImpl(ctx sessionctx.Context, fold int, funcName string, retType
}
}
fc, ok := funcs[funcName]
if !ok {
if extFunc, exist := extensionFuncs.Load(funcName); exist {
fc = extFunc.(functionClass)
ok = true
}
}

if !ok {
db := ctx.GetSessionVars().CurrentDB
if db == "" {
Expand Down
8 changes: 8 additions & 0 deletions extension/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ go_library(
name = "extension",
srcs = [
"extensions.go",
"function.go",
"manifest.go",
"registry.go",
"util.go",
Expand All @@ -12,6 +13,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//sessionctx/variable",
"//types",
"//util/chunk",
"@com_github_pingcap_errors//:errors",
],
Expand All @@ -21,15 +23,21 @@ go_test(
name = "extension_test",
srcs = [
"bootstrap_test.go",
"function_test.go",
"main_test.go",
"registry_test.go",
],
embed = [":extension"],
deps = [
"//expression",
"//parser/auth",
"//privilege/privileges",
"//sessionctx/variable",
"//testkit",
"//testkit/testsetup",
"//types",
"//util/chunk",
"//util/sem",
"@com_github_pingcap_errors//:errors",
"@com_github_stretchr_testify//require",
"@org_uber_go_goleak//:goleak",
Expand Down
1 change: 1 addition & 0 deletions extension/extensionimpl/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ go_library(
"//kv",
"//util/chunk",
"//util/sqlexec",
"@com_github_pingcap_errors//:errors",
],
)
48 changes: 48 additions & 0 deletions extension/function.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright 2022 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package extension

import (
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
)

// FunctionContext is a interface to provide context to the custom function
type FunctionContext interface {
EvalArgs(row chunk.Row) ([]types.Datum, error)
}

// FunctionDef is the definition for the custom function
type FunctionDef struct {
Name string
EvalTp types.EvalType
ArgTps []types.EvalType
// EvalStringFunc is the eval function when `EvalTp` is `types.ETString`
EvalStringFunc func(ctx FunctionContext, row chunk.Row) (string, bool, error)
// EvalIntFunc is the eval function when `EvalTp` is `types.ETInt`
EvalIntFunc func(ctx FunctionContext, row chunk.Row) (int64, bool, error)
// RequireDynamicPrivileges is the dynamic privileges needed to invoke the function
// If `RequireDynamicPrivileges` is empty, it means every one can invoke this function
RequireDynamicPrivileges []string
// SemRequireDynamicPrivileges is the dynamic privileges needed to invoke the function in sem mode
// If `SemRequireDynamicPrivileges` is empty, `DynamicPrivileges` will be used in sem mode
SemRequireDynamicPrivileges []string
}

// RegisterExtensionFunc is to avoid dependency cycle
var RegisterExtensionFunc func(*FunctionDef) error

// RemoveExtensionFunc is to avoid dependency cycle
var RemoveExtensionFunc func(string)
Loading

0 comments on commit ad7f253

Please sign in to comment.