Skip to content

Commit

Permalink
nn: Fix reuse tensors (#911)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenghaoz authored Jan 2, 2025
1 parent e4a4bd6 commit 3aac115
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 10 deletions.
20 changes: 12 additions & 8 deletions common/nn/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,20 @@ func Neg(x *Tensor) *Tensor {
}

// Add returns the element-wise sum of two tensors. The shape of the second tensor must be a suffix sequence of the shape of the first tensor.
func Add(x0, x1 *Tensor) *Tensor {
if len(x0.shape) < len(x1.shape) {
x0, x1 = x1, x0
}
for i := 0; i < len(x1.shape); i++ {
if x0.shape[len(x0.shape)-len(x1.shape)+i] != x1.shape[i] {
panic("the shape of the second tensor must be a suffix sequence of the shape of the first tensor")
func Add(x0 *Tensor, x ...*Tensor) *Tensor {
output := x0
for _, x1 := range x {
if len(x0.shape) < len(x1.shape) {
x0, x1 = x1, x0
}
for i := 0; i < len(x1.shape); i++ {
if x0.shape[len(x0.shape)-len(x1.shape)+i] != x1.shape[i] {
panic("the shape of the second tensor must be a suffix sequence of the shape of the first tensor")
}
}
output = apply(&add{}, output, x1)
}
return apply(&add{}, x0, x1)
return output
}

// Sub returns the element-wise difference of two tensors. The shape of the second tensor must be a suffix sequence of the shape of the first tensor.
Expand Down
125 changes: 124 additions & 1 deletion common/nn/op_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ func TestReshape(t *testing.T) {
assert.Equal(t, []float32{1, 1, 1, 1, 1, 1}, x.grad.data)
}

func TestReuse(t *testing.T) {
func TestReuseLeaf(t *testing.T) {
// x + x
x := NewVariable([]float32{1, 2, 3, 4, 5, 6}, 2, 3)
y := Add(x, x)
Expand All @@ -544,3 +544,126 @@ func TestReuse(t *testing.T) {
dx := numericalDiff(func(x *Tensor) *Tensor { return Add(x, x) }, x)
allClose(t, x.grad, dx)
}

func TestReuseNode(t *testing.T) {
// x^2 + x^2
x := NewVariable([]float32{1, 2, 3, 4, 5, 6}, 2, 3)
temp := Pow(x, NewVariable([]float32{2}))
y := Add(temp, temp)
assert.Equal(t, []float32{2, 8, 18, 32, 50, 72}, y.data)

// Test gradient
y.Backward()
dx := numericalDiff(func(x *Tensor) *Tensor {
temp := Pow(x, NewVariable([]float32{2}))
return Add(temp, temp)
}, x)
allClose(t, x.grad, dx)
}

func TestSphere(t *testing.T) {
// x^2 + y^2
x := NewScalar(1)
y := NewScalar(1)
z := Add(Mul(x, x), Mul(y, y))
assert.Equal(t, []float32{2}, z.data)

// Test gradient
z.Backward()
dx := numericalDiff(func(x *Tensor) *Tensor { return Add(Mul(x, x), Mul(y, y)) }, x)
dy := numericalDiff(func(y *Tensor) *Tensor { return Add(Mul(x, x), Mul(y, y)) }, y)
allClose(t, x.grad, dx)
allClose(t, y.grad, dy)
}

func TestMatyas(t *testing.T) {
// 0.26 * (x^2 + y^2) - 0.48 * x * y
x := NewScalar(1)
y := NewScalar(1)
z := Sub(Mul(NewScalar(0.26), Add(Mul(x, x), Mul(y, y))), Mul(NewScalar(0.48), Mul(x, y)))
assert.InDeltaSlice(t, []float32{0.04}, z.data, 1e-6)

// Test gradient
z.Backward()
dx := numericalDiff(func(x *Tensor) *Tensor {
return Sub(Mul(NewScalar(0.26), Add(Mul(x, x), Mul(y, y))), Mul(NewScalar(0.48), Mul(x, y)))
}, x)
dy := numericalDiff(func(y *Tensor) *Tensor {
return Sub(Mul(NewScalar(0.26), Add(Mul(x, x), Mul(y, y))), Mul(NewScalar(0.48), Mul(x, y)))
}, y)
allClose(t, x.grad, dx)
allClose(t, y.grad, dy)
}

func TestGoldsteinPrice(t *testing.T) {
// (1 + (x + y + 1)^2 * (19 - 14x + 3x^2 - 14y + 6xy + 3y^2)) * (30 + (2x - 3y)^2 * (18 - 32x + 12x^2 + 48y - 36xy + 27y^2))
x := NewScalar(1)
y := NewScalar(1)
z := Mul(
Add(NewScalar(1), Mul(
Pow(Add(x, y, NewScalar(1)), NewScalar(2)), // (x + y + 1)^2
Add(
NewScalar(19), // 19
Mul(NewScalar(-14), x), // -14x
Mul(NewScalar(3), Pow(x, NewScalar(2))), // 3x^2
Mul(NewScalar(-14), y), // -14y
Mul(NewScalar(6), Mul(x, y)), // 6xy
Mul(NewScalar(3), Pow(y, NewScalar(2)))))), // 3y^2
Add(NewScalar(30), Mul(
Pow(Sub(Mul(NewScalar(2), x), Mul(NewScalar(3), y)), NewScalar(2)), // (2x - 3y)^2
Add(
NewScalar(18), // 18
Mul(NewScalar(-32), x), // -32x
Mul(NewScalar(12), Pow(x, NewScalar(2))), // 12x^2
Mul(NewScalar(48), y), // 48y
Mul(NewScalar(-36), Mul(x, y)), // -36xy
Mul(NewScalar(27), Pow(y, NewScalar(2))))))) // 27y^2
assert.InDeltaSlice(t, []float32{1876}, z.data, 1e-6)

// Test gradient
z.Backward()
dx := numericalDiff(func(x *Tensor) *Tensor {
return Mul(
Add(NewScalar(1), Mul(
Pow(Add(x, y, NewScalar(1)), NewScalar(2)), // (x + y + 1)^2
Add(
NewScalar(19), // 19
Mul(NewScalar(-14), x), // -14x
Mul(NewScalar(3), Pow(x, NewScalar(2))), // 3x^2
Mul(NewScalar(-14), y), // -14y
Mul(NewScalar(6), Mul(x, y)), // 6xy
Mul(NewScalar(3), Pow(y, NewScalar(2)))))), // 3y^2
Add(NewScalar(30), Mul(
Pow(Sub(Mul(NewScalar(2), x), Mul(NewScalar(3), y)), NewScalar(2)), // (2x - 3y)^2
Add(
NewScalar(18), // 18
Mul(NewScalar(-32), x), // -32x
Mul(NewScalar(12), Pow(x, NewScalar(2))), // 12x^2
Mul(NewScalar(48), y), // 48y
Mul(NewScalar(-36), Mul(x, y)), // -36xy
Mul(NewScalar(27), Pow(y, NewScalar(2))))))) // 27y^2
}, x)
dy := numericalDiff(func(y *Tensor) *Tensor {
return Mul(
Add(NewScalar(1), Mul(
Pow(Add(x, y, NewScalar(1)), NewScalar(2)), // (x + y + 1)^2
Add(
NewScalar(19), // 19
Mul(NewScalar(-14), x), // -14x
Mul(NewScalar(3), Pow(x, NewScalar(2))), // 3x^2
Mul(NewScalar(-14), y), // -14y
Mul(NewScalar(6), Mul(x, y)), // 6xy
Mul(NewScalar(3), Pow(y, NewScalar(2)))))), // 3y^2
Add(NewScalar(30), Mul(
Pow(Sub(Mul(NewScalar(2), x), Mul(NewScalar(3), y)), NewScalar(2)), // (2x - 3y)^2
Add(
NewScalar(18), // 18
Mul(NewScalar(-32), x), // -32x
Mul(NewScalar(12), Pow(x, NewScalar(2))), // 12x^2
Mul(NewScalar(48), y), // 48y
Mul(NewScalar(-36), Mul(x, y)), // -36xy
Mul(NewScalar(27), Pow(y, NewScalar(2))))))) // 27y^2
}, y)
allClose(t, x.grad, dx)
allClose(t, y.grad, dy)
}
5 changes: 4 additions & 1 deletion common/nn/tensor.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package nn
import (
"fmt"
"github.com/chewxy/math32"
mapset "github.com/deckarep/golang-set/v2"
"github.com/google/uuid"
"github.com/zhenghaoz/gorse/base/floats"
"golang.org/x/exp/slices"
Expand Down Expand Up @@ -209,6 +210,7 @@ func (t *Tensor) String() string {
func (t *Tensor) Backward() {
t.grad = Ones(t.shape...)
ops := []op{t.op}
seen := mapset.NewSet[op](t.op)
for len(ops) > 0 {
op := ops[0]
ops = ops[1:]
Expand All @@ -225,8 +227,9 @@ func (t *Tensor) Backward() {
} else {
inputs[i].grad.add(grads[i])
}
if inputs[i].op != nil {
if inputs[i].op != nil && !seen.Contains(inputs[i].op) {
ops = append(ops, inputs[i].op)
seen.Add(inputs[i].op)
} else if !inputs[i].requireGrad {
// Clear gradient if the leaf tensor does not require gradient
//inputs[i].grad = nil
Expand Down

0 comments on commit 3aac115

Please sign in to comment.