Skip to content

Commit

Permalink
feat: Operator overload from Function (#408)
Browse files Browse the repository at this point in the history
* feat: rewritten Operator overload from Function

* fix: add test
  • Loading branch information
nikolaymatrosov authored Feb 17, 2024
1 parent 1719809 commit 4cac5f6
Show file tree
Hide file tree
Showing 5 changed files with 246 additions and 17 deletions.
2 changes: 1 addition & 1 deletion checker/checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ func (v *checker) BinaryNode(node *ast.BinaryNode) (reflect.Type, info) {

// check operator overloading
if fns, ok := v.config.Operators[node.Operator]; ok {
t, _, ok := conf.FindSuitableOperatorOverload(fns, v.config.Types, l, r)
t, _, ok := conf.FindSuitableOperatorOverload(fns, v.config.Types, v.config.Functions, l, r)
if ok {
return t, info{}
}
Expand Down
38 changes: 31 additions & 7 deletions conf/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"github.com/expr-lang/expr/vm/runtime"
)

type FunctionTable map[string]*builtin.Function

type Config struct {
Env any
Types TypesTable
Expand Down Expand Up @@ -85,21 +87,43 @@ func (c *Config) ConstExpr(name string) {
func (c *Config) Check() {
for operator, fns := range c.Operators {
for _, fn := range fns {
fnType, ok := c.Types[fn]
if !ok || fnType.Type.Kind() != reflect.Func {
fnType, foundType := c.Types[fn]
fnFunc, foundFunc := c.Functions[fn]
if !foundFunc && (!foundType || fnType.Type.Kind() != reflect.Func) {
panic(fmt.Errorf("function %s for %s operator does not exist in the environment", fn, operator))
}
requiredNumIn := 2
if fnType.Method {
requiredNumIn = 3 // As first argument of method is receiver.

if foundType {
checkType(fnType, fn, operator)
}
if fnType.Type.NumIn() != requiredNumIn || fnType.Type.NumOut() != 1 {
panic(fmt.Errorf("function %s for %s operator does not have a correct signature", fn, operator))
if foundFunc {
checkFunc(fnFunc, fn, operator)
}
}
}
}

func checkType(fnType Tag, fn string, operator string) {
requiredNumIn := 2
if fnType.Method {
requiredNumIn = 3 // As first argument of method is receiver.
}
if fnType.Type.NumIn() != requiredNumIn || fnType.Type.NumOut() != 1 {
panic(fmt.Errorf("function %s for %s operator does not have a correct signature", fn, operator))
}
}

func checkFunc(fn *builtin.Function, name string, operator string) {
if len(fn.Types) == 0 {
panic(fmt.Errorf("function %s for %s operator misses types", name, operator))
}
for _, t := range fn.Types {
if t.NumIn() != 2 || t.NumOut() != 1 {
panic(fmt.Errorf("function %s for %s operator does not have a correct signature", name, operator))
}
}
}

func (c *Config) IsOverridden(name string) bool {
if _, ok := c.Functions[name]; ok {
return true
Expand Down
55 changes: 46 additions & 9 deletions conf/operators.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,65 @@ import (
// Functions should be provided in the environment to allow operator overloading.
type OperatorsTable map[string][]string

func FindSuitableOperatorOverload(fns []string, types TypesTable, l, r reflect.Type) (reflect.Type, string, bool) {
func FindSuitableOperatorOverload(fns []string, types TypesTable, funcs FunctionTable, l, r reflect.Type) (reflect.Type, string, bool) {
t, fn, ok := FindSuitableOperatorOverloadInFunctions(fns, funcs, l, r)
if !ok {
t, fn, ok = FindSuitableOperatorOverloadInTypes(fns, types, l, r)
}
return t, fn, ok
}

func FindSuitableOperatorOverloadInTypes(fns []string, types TypesTable, l, r reflect.Type) (reflect.Type, string, bool) {
for _, fn := range fns {
fnType := types[fn]
fnType, ok := types[fn]
if !ok {
continue
}
firstInIndex := 0
if fnType.Method {
firstInIndex = 1 // As first argument to method is receiver.
}
firstArgType := fnType.Type.In(firstInIndex)
secondArgType := fnType.Type.In(firstInIndex + 1)
ret, done := checkTypeSuits(fnType.Type, l, r, firstInIndex)
if done {
return ret, fn, true
}
}
return nil, "", false
}

firstArgumentFit := l == firstArgType || (firstArgType.Kind() == reflect.Interface && (l == nil || l.Implements(firstArgType)))
secondArgumentFit := r == secondArgType || (secondArgType.Kind() == reflect.Interface && (r == nil || r.Implements(secondArgType)))
if firstArgumentFit && secondArgumentFit {
return fnType.Type.Out(0), fn, true
func FindSuitableOperatorOverloadInFunctions(fns []string, funcs FunctionTable, l, r reflect.Type) (reflect.Type, string, bool) {
for _, fn := range fns {
fnType, ok := funcs[fn]
if !ok {
continue
}
firstInIndex := 0
for _, overload := range fnType.Types {
ret, done := checkTypeSuits(overload, l, r, firstInIndex)
if done {
return ret, fn, true
}
}
}
return nil, "", false
}

func checkTypeSuits(t reflect.Type, l reflect.Type, r reflect.Type, firstInIndex int) (reflect.Type, bool) {
firstArgType := t.In(firstInIndex)
secondArgType := t.In(firstInIndex + 1)

firstArgumentFit := l == firstArgType || (firstArgType.Kind() == reflect.Interface && (l == nil || l.Implements(firstArgType)))
secondArgumentFit := r == secondArgType || (secondArgType.Kind() == reflect.Interface && (r == nil || r.Implements(secondArgType)))
if firstArgumentFit && secondArgumentFit {
return t.Out(0), true
}
return nil, false
}

type OperatorPatcher struct {
Operators OperatorsTable
Types TypesTable
Functions FunctionTable
}

func (p *OperatorPatcher) Visit(node *ast.Node) {
Expand All @@ -48,7 +85,7 @@ func (p *OperatorPatcher) Visit(node *ast.Node) {
leftType := binaryNode.Left.Type()
rightType := binaryNode.Right.Type()

ret, fn, ok := FindSuitableOperatorOverload(fns, p.Types, leftType, rightType)
ret, fn, ok := FindSuitableOperatorOverload(fns, p.Types, p.Functions, leftType, rightType)
if ok {
newNode := &ast.CallNode{
Callee: &ast.IdentifierNode{Value: fn},
Expand Down
1 change: 1 addition & 0 deletions expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ func Compile(input string, ops ...Option) (*vm.Program, error) {
config.Visitors = append(config.Visitors, &conf.OperatorPatcher{
Operators: config.Operators,
Types: config.Types,
Functions: config.Functions,
})
}

Expand Down
167 changes: 167 additions & 0 deletions test/operator/operator_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package operator_test

import (
"fmt"
"testing"
"time"

Expand Down Expand Up @@ -55,3 +56,169 @@ func TestOperator_interface(t *testing.T) {
require.NoError(t, err)
require.Equal(t, true, output)
}

type Value struct {
Int int
}

func TestOperator_Function(t *testing.T) {
env := map[string]interface{}{
"foo": Value{1},
"bar": Value{2},
}

tests := []struct {
input string
want int
}{
{
input: `foo + bar`,
want: 3,
},
{
input: `2 + 4`,
want: 6,
},
}

for _, tt := range tests {
t.Run(fmt.Sprintf(`opertor function helper test %s`, tt.input), func(t *testing.T) {
program, err := expr.Compile(
tt.input,
expr.Env(env),
expr.Operator("+", "Add", "AddInt"),
expr.Function("Add", func(args ...interface{}) (interface{}, error) {
return args[0].(Value).Int + args[1].(Value).Int, nil
},
new(func(_ Value, __ Value) int),
),
expr.Function("AddInt", func(args ...interface{}) (interface{}, error) {
return args[0].(int) + args[1].(int), nil
},
new(func(_ int, __ int) int),
),
)
require.NoError(t, err)

output, err := expr.Run(program, env)
require.NoError(t, err)
require.Equal(t, tt.want, output)
})
}

}

func TestOperator_Function_WithTypes(t *testing.T) {
env := map[string]interface{}{
"foo": Value{1},
"bar": Value{2},
}

require.PanicsWithError(t, `function Add for + operator misses types`, func() {
_, _ = expr.Compile(
`foo + bar`,
expr.Env(env),
expr.Operator("+", "Add", "AddInt"),
expr.Function("Add", func(args ...interface{}) (interface{}, error) {
return args[0].(Value).Int + args[1].(Value).Int, nil
}),
)
})

require.PanicsWithError(t, `function Add for + operator does not have a correct signature`, func() {
_, _ = expr.Compile(
`foo + bar`,
expr.Env(env),
expr.Operator("+", "Add", "AddInt"),
expr.Function("Add", func(args ...interface{}) (interface{}, error) {
return args[0].(Value).Int + args[1].(Value).Int, nil
},
new(func(_ Value) int),
),
)
})

}

func TestOperator_FunctionOverTypesPrecedence(t *testing.T) {
env := struct {
Add func(a, b int) int
}{
Add: func(a, b int) int {
return a + b
},
}

program, err := expr.Compile(
`1 + 2`,
expr.Env(env),
expr.Operator("+", "Add"),
expr.Function("Add", func(args ...interface{}) (interface{}, error) {
// Wierd function that returns 100 + a + b in testing purposes.
return args[0].(int) + args[1].(int) + 100, nil
},
new(func(_ int, __ int) int),
),
)
require.NoError(t, err)

output, err := expr.Run(program, env)
require.NoError(t, err)
require.Equal(t, 103, output)
}

func TestOperator_CanBeDefinedEitherInTypesOrInFunctions(t *testing.T) {
env := struct {
Add func(a, b int) int
}{
Add: func(a, b int) int {
return a + b
},
}

program, err := expr.Compile(
`1 + 2`,
expr.Env(env),
expr.Operator("+", "Add", "AddValues"),
expr.Function("AddValues", func(args ...interface{}) (interface{}, error) {
return args[0].(Value).Int + args[1].(Value).Int, nil
},
new(func(_ Value, __ Value) int),
),
)
require.NoError(t, err)

output, err := expr.Run(program, env)
require.NoError(t, err)
require.Equal(t, 3, output)
}

func TestOperator_Polymorphic(t *testing.T) {
env := struct {
Add func(a, b int) int
Foo Value
Bar Value
}{
Add: func(a, b int) int {
return a + b
},
Foo: Value{1},
Bar: Value{2},
}

program, err := expr.Compile(
`1 + 2 + (Foo + Bar)`,
expr.Env(env),
expr.Operator("+", "Add", "AddValues"),
expr.Function("AddValues", func(args ...interface{}) (interface{}, error) {
return args[0].(Value).Int + args[1].(Value).Int, nil
},
new(func(_ Value, __ Value) int),
),
)
require.NoError(t, err)

output, err := expr.Run(program, env)
require.NoError(t, err)
require.Equal(t, 6, output)
}

0 comments on commit 4cac5f6

Please sign in to comment.