From b26d4b7339ed214b829ec0a22670941702e59828 Mon Sep 17 00:00:00 2001 From: Anton Medvedev Date: Tue, 8 Nov 2022 20:58:23 +0100 Subject: [PATCH] Fix constant folding for floats and ints --- expr_test.go | 4 +- optimizer/fold.go | 182 ++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 163 insertions(+), 23 deletions(-) diff --git a/expr_test.go b/expr_test.go index 844cd2f87..8359c1f2e 100644 --- a/expr_test.go +++ b/expr_test.go @@ -1386,11 +1386,11 @@ func TestIssue138(t *testing.T) { env := map[string]interface{}{} _, err := expr.Compile(`1 / (1 - 1)`, expr.Env(env)) - require.Error(t, err) - require.Equal(t, "integer divide by zero (1:3)\n | 1 / (1 - 1)\n | ..^", err.Error()) + require.NoError(t, err) _, err = expr.Compile(`1 % 0`, expr.Env(env)) require.Error(t, err) + require.Equal(t, "integer divide by zero (1:3)\n | 1 % 0\n | ..^", err.Error()) } func TestIssue154(t *testing.T) { diff --git a/optimizer/fold.go b/optimizer/fold.go index 4d5dc59fa..d6706ee03 100644 --- a/optimizer/fold.go +++ b/optimizer/fold.go @@ -32,48 +32,141 @@ func (fold *fold) Visit(node *Node) { if i, ok := n.Node.(*IntegerNode); ok { patchWithType(&IntegerNode{Value: -i.Value}, n.Node.Type()) } + if i, ok := n.Node.(*FloatNode); ok { + patchWithType(&FloatNode{Value: -i.Value}, n.Node.Type()) + } case "+": if i, ok := n.Node.(*IntegerNode); ok { patchWithType(&IntegerNode{Value: i.Value}, n.Node.Type()) } + if i, ok := n.Node.(*FloatNode); ok { + patchWithType(&FloatNode{Value: i.Value}, n.Node.Type()) + } } case *BinaryNode: switch n.Operator { case "+": - if a, ok := n.Left.(*IntegerNode); ok { - if b, ok := n.Right.(*IntegerNode); ok { + { + a := toInteger(n.Left) + b := toInteger(n.Right) + if a != nil && b != nil { patchWithType(&IntegerNode{Value: a.Value + b.Value}, a.Type()) } } - if a, ok := n.Left.(*StringNode); ok { - if b, ok := n.Right.(*StringNode); ok { + { + a := toInteger(n.Left) + b := toFloat(n.Right) + if a != nil && b != nil { + patchWithType(&FloatNode{Value: float64(a.Value) + b.Value}, a.Type()) + } + } + { + a := toFloat(n.Left) + b := toInteger(n.Right) + if a != nil && b != nil { + patchWithType(&FloatNode{Value: a.Value + float64(b.Value)}, a.Type()) + } + } + { + a := toFloat(n.Left) + b := toFloat(n.Right) + if a != nil && b != nil { + patchWithType(&FloatNode{Value: a.Value + b.Value}, a.Type()) + } + } + { + a := toString(n.Left) + b := toString(n.Right) + if a != nil && b != nil { patch(&StringNode{Value: a.Value + b.Value}) } } case "-": - if a, ok := n.Left.(*IntegerNode); ok { - if b, ok := n.Right.(*IntegerNode); ok { + { + a := toInteger(n.Left) + b := toInteger(n.Right) + if a != nil && b != nil { patchWithType(&IntegerNode{Value: a.Value - b.Value}, a.Type()) } } + { + a := toInteger(n.Left) + b := toFloat(n.Right) + if a != nil && b != nil { + patchWithType(&FloatNode{Value: float64(a.Value) - b.Value}, a.Type()) + } + } + { + a := toFloat(n.Left) + b := toInteger(n.Right) + if a != nil && b != nil { + patchWithType(&FloatNode{Value: a.Value - float64(b.Value)}, a.Type()) + } + } + { + a := toFloat(n.Left) + b := toFloat(n.Right) + if a != nil && b != nil { + patchWithType(&FloatNode{Value: a.Value - b.Value}, a.Type()) + } + } case "*": - if a, ok := n.Left.(*IntegerNode); ok { - if b, ok := n.Right.(*IntegerNode); ok { + { + a := toInteger(n.Left) + b := toInteger(n.Right) + if a != nil && b != nil { patchWithType(&IntegerNode{Value: a.Value * b.Value}, a.Type()) } } + { + a := toInteger(n.Left) + b := toFloat(n.Right) + if a != nil && b != nil { + patchWithType(&FloatNode{Value: float64(a.Value) * b.Value}, a.Type()) + } + } + { + a := toFloat(n.Left) + b := toInteger(n.Right) + if a != nil && b != nil { + patchWithType(&FloatNode{Value: a.Value * float64(b.Value)}, a.Type()) + } + } + { + a := toFloat(n.Left) + b := toFloat(n.Right) + if a != nil && b != nil { + patchWithType(&FloatNode{Value: a.Value * b.Value}, a.Type()) + } + } case "/": - if a, ok := n.Left.(*IntegerNode); ok { - if b, ok := n.Right.(*IntegerNode); ok { - if b.Value == 0 { - fold.err = &file.Error{ - Location: (*node).Location(), - Message: "integer divide by zero", - } - return - } - patchWithType(&IntegerNode{Value: a.Value / b.Value}, a.Type()) + { + a := toInteger(n.Left) + b := toInteger(n.Right) + if a != nil && b != nil { + patchWithType(&FloatNode{Value: float64(a.Value) / float64(b.Value)}, a.Type()) + } + } + { + a := toInteger(n.Left) + b := toFloat(n.Right) + if a != nil && b != nil { + patchWithType(&FloatNode{Value: float64(a.Value) / b.Value}, a.Type()) + } + } + { + a := toFloat(n.Left) + b := toInteger(n.Right) + if a != nil && b != nil { + patchWithType(&FloatNode{Value: a.Value / float64(b.Value)}, a.Type()) + } + } + { + a := toFloat(n.Left) + b := toFloat(n.Right) + if a != nil && b != nil { + patchWithType(&FloatNode{Value: a.Value / b.Value}, a.Type()) } } case "%": @@ -90,9 +183,32 @@ func (fold *fold) Visit(node *Node) { } } case "**", "^": - if a, ok := n.Left.(*IntegerNode); ok { - if b, ok := n.Right.(*IntegerNode); ok { - patch(&FloatNode{Value: math.Pow(float64(a.Value), float64(b.Value))}) + { + a := toInteger(n.Left) + b := toInteger(n.Right) + if a != nil && b != nil { + patchWithType(&FloatNode{Value: math.Pow(float64(a.Value), float64(b.Value))}, a.Type()) + } + } + { + a := toInteger(n.Left) + b := toFloat(n.Right) + if a != nil && b != nil { + patchWithType(&FloatNode{Value: math.Pow(float64(a.Value), b.Value)}, a.Type()) + } + } + { + a := toFloat(n.Left) + b := toInteger(n.Right) + if a != nil && b != nil { + patchWithType(&FloatNode{Value: math.Pow(a.Value, float64(b.Value))}, a.Type()) + } + } + { + a := toFloat(n.Left) + b := toFloat(n.Right) + if a != nil && b != nil { + patchWithType(&FloatNode{Value: math.Pow(a.Value, b.Value)}, a.Type()) } } } @@ -145,3 +261,27 @@ func (fold *fold) Visit(node *Node) { } } } + +func toString(n Node) *StringNode { + switch a := n.(type) { + case *StringNode: + return a + } + return nil +} + +func toInteger(n Node) *IntegerNode { + switch a := n.(type) { + case *IntegerNode: + return a + } + return nil +} + +func toFloat(n Node) *FloatNode { + switch a := n.(type) { + case *FloatNode: + return a + } + return nil +}