Skip to content

Commit

Permalink
nn: Fix Adam optimizer (#918)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenghaoz authored Jan 6, 2025
1 parent 515c3c9 commit 21a724c
Show file tree
Hide file tree
Showing 7 changed files with 164 additions and 140 deletions.
6 changes: 3 additions & 3 deletions common/nn/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func Add(x0 *Tensor, x ...*Tensor) *Tensor {
output := x0
for _, x1 := range x {
if len(x0.shape) < len(x1.shape) {
x0, x1 = x1, x0
output, x1 = x1, output
}
for i := 0; i < len(x1.shape); i++ {
if x0.shape[len(x0.shape)-len(x1.shape)+i] != x1.shape[i] {
Expand Down Expand Up @@ -214,7 +214,7 @@ func SoftmaxCrossEntropy(x, y *Tensor) *Tensor {
//
// (1 + target) * math32.Log(1+math32.Exp(-prediction)) / 2 + (1 - target) * math32.Log(1+math32.Exp(prediction)) / 2
func BCEWithLogits(target, prediction *Tensor) *Tensor {
return Add(
return Mean(Add(
Div(
Mul(
Add(NewScalar(1), target),
Expand All @@ -224,5 +224,5 @@ func BCEWithLogits(target, prediction *Tensor) *Tensor {
Mul(
Sub(NewScalar(1), target),
Log(Add(NewScalar(1), Exp(prediction)))),
NewScalar(2)))
NewScalar(2))))
}
7 changes: 4 additions & 3 deletions common/nn/layers.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@ type LinearLayer struct {
}

func NewLinear(in, out int) Layer {
bound := 1.0 / math32.Sqrt(float32(in))
return &LinearLayer{
W: Normal(0, 1.0/math32.Sqrt(float32(in)), in, out).RequireGrad(),
B: Zeros(out).RequireGrad(),
W: Uniform(-bound, bound, in, out),
B: Zeros(out),
}
}

Expand Down Expand Up @@ -73,7 +74,7 @@ type EmbeddingLayer struct {
func NewEmbedding(n int, shape ...int) Layer {
wShape := append([]int{n}, shape...)
return &EmbeddingLayer{
W: Rand(wShape...),
W: Normal(0, 0.01, wShape...),
}
}

Expand Down
47 changes: 32 additions & 15 deletions common/nn/nn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@ import (
"strconv"
"strings"
"testing"
"time"

"github.com/chewxy/math32"
"github.com/klauspost/cpuid/v2"
"github.com/samber/lo"
"github.com/schollz/progressbar/v3"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -208,7 +210,25 @@ func openMNISTFile(path string) (*Tensor, *Tensor, error) {
return NewTensor(images, len(labels), 784), NewTensor(labels, len(labels)), nil
}

// accuracy returns the fraction of rows in prediction whose argmax class
// matches the integer label stored in target.
func accuracy(prediction, target *Tensor) float32 {
	var correct float32
	for i := range target.data {
		label := int(target.data[i])
		// argmax()[1] is the column (class) index of the row's maximum.
		if prediction.Slice(i, i+1).argmax()[1] == label {
			correct++
		}
	}
	return correct / float32(len(target.data))
}

func TestMNIST(t *testing.T) {
if cpuid.CPU.VendorString != "Apple" && !cpuid.CPU.Supports(cpuid.AVX512F, cpuid.AVX512DQ) {
// Since the test takes a long time, we run the test only in development environment.
// 1. Mac with Apple Silicon.
// 2. x86 CPU with AVX512 support.
t.Skip("Skip test on non-development environment.")
}

train, test, err := mnist()
assert.NoError(t, err)

Expand All @@ -219,13 +239,14 @@ func TestMNIST(t *testing.T) {
)
optimizer := NewAdam(model.Parameters(), 0.001)

var (
sumLoss float32
const (
batchSize = 1000
numEpoch = 5
)
for i := 0; i < 3; i++ {
sumLoss = 0
bar := progressbar.Default(int64(train.A.shape[0]), fmt.Sprintf("Epoch %v/%v", i+1, 3))
for i := 0; i < numEpoch; i++ {
startTime := time.Now()
sumLoss, sumAcc := float32(0), float32(0)
bar := progressbar.Default(int64(train.A.shape[0]), fmt.Sprintf("Epoch %v/%v", i+1, numEpoch))
for j := 0; j < train.A.shape[0]; j += batchSize {
xBatch := train.A.Slice(j, j+batchSize)
yBatch := train.B.Slice(j, j+batchSize)
Expand All @@ -238,22 +259,18 @@ func TestMNIST(t *testing.T) {

optimizer.Step()
sumLoss += loss.data[0]
sumAcc += accuracy(yPred, yBatch)
bar.Add(batchSize)

Check failure on line 263 in common/nn/nn_test.go

View workflow job for this annotation

GitHub Actions / lint

Error return value of `bar.Add` is not checked (errcheck)
}
sumLoss /= float32(train.A.shape[0] / batchSize)
sumAcc /= float32(train.A.shape[0] / batchSize)
bar.Finish()

Check failure on line 267 in common/nn/nn_test.go

View workflow job for this annotation

GitHub Actions / lint

Error return value of `bar.Finish` is not checked (errcheck)
fmt.Println("Duration:", time.Since(startTime), "Loss:", sumLoss, "Accuracy:", sumAcc)
}
assert.Less(t, sumLoss, float32(0.4))

testPred := model.Forward(test.A)
var precision float32
for i, gt := range test.B.data {
if testPred.Slice(i, i+1).argmax()[1] == int(gt) {
precision += 1
}
}
precision /= float32(len(test.B.data))
assert.Greater(t, float64(precision), 0.92)
testAcc := accuracy(model.Forward(test.A), test.B)
fmt.Println("Test Accuracy:", testAcc)
assert.Greater(t, float64(testAcc), 0.96)
}

func spiral() (*Tensor, *Tensor, error) {
Expand Down
Loading

0 comments on commit 21a724c

Please sign in to comment.