Skip to content

Commit

Permalink
Fix constant folding for floats and ints
Browse files Browse the repository at this point in the history
  • Loading branch information
antonmedv committed Nov 8, 2022
1 parent 3d4c219 commit b26d4b7
Show file tree
Hide file tree
Showing 2 changed files with 163 additions and 23 deletions.
4 changes: 2 additions & 2 deletions expr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1386,11 +1386,11 @@ func TestIssue138(t *testing.T) {
env := map[string]interface{}{}

_, err := expr.Compile(`1 / (1 - 1)`, expr.Env(env))
require.Error(t, err)
require.Equal(t, "integer divide by zero (1:3)\n | 1 / (1 - 1)\n | ..^", err.Error())
require.NoError(t, err)

_, err = expr.Compile(`1 % 0`, expr.Env(env))
require.Error(t, err)
require.Equal(t, "integer divide by zero (1:3)\n | 1 % 0\n | ..^", err.Error())
}

func TestIssue154(t *testing.T) {
Expand Down
182 changes: 161 additions & 21 deletions optimizer/fold.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,48 +32,141 @@ func (fold *fold) Visit(node *Node) {
if i, ok := n.Node.(*IntegerNode); ok {
patchWithType(&IntegerNode{Value: -i.Value}, n.Node.Type())
}
if i, ok := n.Node.(*FloatNode); ok {
patchWithType(&FloatNode{Value: -i.Value}, n.Node.Type())
}
case "+":
if i, ok := n.Node.(*IntegerNode); ok {
patchWithType(&IntegerNode{Value: i.Value}, n.Node.Type())
}
if i, ok := n.Node.(*FloatNode); ok {
patchWithType(&FloatNode{Value: i.Value}, n.Node.Type())
}
}

case *BinaryNode:
switch n.Operator {
case "+":
if a, ok := n.Left.(*IntegerNode); ok {
if b, ok := n.Right.(*IntegerNode); ok {
{
a := toInteger(n.Left)
b := toInteger(n.Right)
if a != nil && b != nil {
patchWithType(&IntegerNode{Value: a.Value + b.Value}, a.Type())
}
}
if a, ok := n.Left.(*StringNode); ok {
if b, ok := n.Right.(*StringNode); ok {
{
a := toInteger(n.Left)
b := toFloat(n.Right)
if a != nil && b != nil {
patchWithType(&FloatNode{Value: float64(a.Value) + b.Value}, a.Type())
}
}
{
a := toFloat(n.Left)
b := toInteger(n.Right)
if a != nil && b != nil {
patchWithType(&FloatNode{Value: a.Value + float64(b.Value)}, a.Type())
}
}
{
a := toFloat(n.Left)
b := toFloat(n.Right)
if a != nil && b != nil {
patchWithType(&FloatNode{Value: a.Value + b.Value}, a.Type())
}
}
{
a := toString(n.Left)
b := toString(n.Right)
if a != nil && b != nil {
patch(&StringNode{Value: a.Value + b.Value})
}
}
case "-":
if a, ok := n.Left.(*IntegerNode); ok {
if b, ok := n.Right.(*IntegerNode); ok {
{
a := toInteger(n.Left)
b := toInteger(n.Right)
if a != nil && b != nil {
patchWithType(&IntegerNode{Value: a.Value - b.Value}, a.Type())
}
}
{
a := toInteger(n.Left)
b := toFloat(n.Right)
if a != nil && b != nil {
patchWithType(&FloatNode{Value: float64(a.Value) - b.Value}, a.Type())
}
}
{
a := toFloat(n.Left)
b := toInteger(n.Right)
if a != nil && b != nil {
patchWithType(&FloatNode{Value: a.Value - float64(b.Value)}, a.Type())
}
}
{
a := toFloat(n.Left)
b := toFloat(n.Right)
if a != nil && b != nil {
patchWithType(&FloatNode{Value: a.Value - b.Value}, a.Type())
}
}
case "*":
if a, ok := n.Left.(*IntegerNode); ok {
if b, ok := n.Right.(*IntegerNode); ok {
{
a := toInteger(n.Left)
b := toInteger(n.Right)
if a != nil && b != nil {
patchWithType(&IntegerNode{Value: a.Value * b.Value}, a.Type())
}
}
{
a := toInteger(n.Left)
b := toFloat(n.Right)
if a != nil && b != nil {
patchWithType(&FloatNode{Value: float64(a.Value) * b.Value}, a.Type())
}
}
{
a := toFloat(n.Left)
b := toInteger(n.Right)
if a != nil && b != nil {
patchWithType(&FloatNode{Value: a.Value * float64(b.Value)}, a.Type())
}
}
{
a := toFloat(n.Left)
b := toFloat(n.Right)
if a != nil && b != nil {
patchWithType(&FloatNode{Value: a.Value * b.Value}, a.Type())
}
}
case "/":
if a, ok := n.Left.(*IntegerNode); ok {
if b, ok := n.Right.(*IntegerNode); ok {
if b.Value == 0 {
fold.err = &file.Error{
Location: (*node).Location(),
Message: "integer divide by zero",
}
return
}
patchWithType(&IntegerNode{Value: a.Value / b.Value}, a.Type())
{
a := toInteger(n.Left)
b := toInteger(n.Right)
if a != nil && b != nil {
patchWithType(&FloatNode{Value: float64(a.Value) / float64(b.Value)}, a.Type())
}
}
{
a := toInteger(n.Left)
b := toFloat(n.Right)
if a != nil && b != nil {
patchWithType(&FloatNode{Value: float64(a.Value) / b.Value}, a.Type())
}
}
{
a := toFloat(n.Left)
b := toInteger(n.Right)
if a != nil && b != nil {
patchWithType(&FloatNode{Value: a.Value / float64(b.Value)}, a.Type())
}
}
{
a := toFloat(n.Left)
b := toFloat(n.Right)
if a != nil && b != nil {
patchWithType(&FloatNode{Value: a.Value / b.Value}, a.Type())
}
}
case "%":
Expand All @@ -90,9 +183,32 @@ func (fold *fold) Visit(node *Node) {
}
}
case "**", "^":
if a, ok := n.Left.(*IntegerNode); ok {
if b, ok := n.Right.(*IntegerNode); ok {
patch(&FloatNode{Value: math.Pow(float64(a.Value), float64(b.Value))})
{
a := toInteger(n.Left)
b := toInteger(n.Right)
if a != nil && b != nil {
patchWithType(&FloatNode{Value: math.Pow(float64(a.Value), float64(b.Value))}, a.Type())
}
}
{
a := toInteger(n.Left)
b := toFloat(n.Right)
if a != nil && b != nil {
patchWithType(&FloatNode{Value: math.Pow(float64(a.Value), b.Value)}, a.Type())
}
}
{
a := toFloat(n.Left)
b := toInteger(n.Right)
if a != nil && b != nil {
patchWithType(&FloatNode{Value: math.Pow(a.Value, float64(b.Value))}, a.Type())
}
}
{
a := toFloat(n.Left)
b := toFloat(n.Right)
if a != nil && b != nil {
patchWithType(&FloatNode{Value: math.Pow(a.Value, b.Value)}, a.Type())
}
}
}
Expand Down Expand Up @@ -145,3 +261,27 @@ func (fold *fold) Visit(node *Node) {
}
}
}

func toString(n Node) *StringNode {
switch a := n.(type) {
case *StringNode:
return a
}
return nil
}

func toInteger(n Node) *IntegerNode {
switch a := n.(type) {
case *IntegerNode:
return a
}
return nil
}

func toFloat(n Node) *FloatNode {
switch a := n.(type) {
case *FloatNode:
return a
}
return nil
}

0 comments on commit b26d4b7

Please sign in to comment.