Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

nn: Fix MNIST #916

Merged
merged 6 commits into from
Jan 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ jobs:
uses: actions/checkout@v2

- name: Test
run: go test -timeout 20m -v ./... -coverprofile=coverage.txt -covermode=atomic -coverpkg=./...
run: go test -timeout 30m -v ./... -coverprofile=coverage.txt -covermode=atomic -coverpkg=./...
env:
# MySQL
MYSQL_URI: mysql://root:password@tcp(localhost:${{ job.services.mysql.ports[3306] }})/
Expand Down
4 changes: 3 additions & 1 deletion common/nn/layers.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

package nn

import "github.com/chewxy/math32"

type Layer interface {
Parameters() []*Tensor
Forward(x *Tensor) *Tensor
Expand All @@ -28,7 +30,7 @@ type linearLayer struct {

func NewLinear(in, out int) Layer {
return &linearLayer{
w: Rand(in, out).RequireGrad(),
w: Normal(0, 1.0/math32.Sqrt(float32(in)), in, out).RequireGrad(),
b: Zeros(out).RequireGrad(),
}
}
Expand Down
127 changes: 121 additions & 6 deletions common/nn/nn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,21 @@
package nn

import (
"bufio"
"encoding/csv"
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"testing"

"github.com/chewxy/math32"
"github.com/samber/lo"
"github.com/schollz/progressbar/v3"
"github.com/stretchr/testify/assert"
"github.com/zhenghaoz/gorse/common/dataset"
"github.com/zhenghaoz/gorse/common/util"
"os"
"path/filepath"
"testing"
)

func TestLinearRegression(t *testing.T) {
Expand All @@ -47,9 +54,9 @@ func TestLinearRegression(t *testing.T) {
}

assert.Equal(t, []int{1, 1}, w.shape)
assert.InDelta(t, float64(2), w.data[0], 0.5)
assert.InDelta(t, float64(2), w.data[0], 0.6)
assert.Equal(t, []int{1}, b.shape)
assert.InDelta(t, float64(5), b.data[0], 0.5)
assert.InDelta(t, float64(5), b.data[0], 0.6)
}

func TestNeuralNetwork(t *testing.T) {
Expand All @@ -76,7 +83,7 @@ func TestNeuralNetwork(t *testing.T) {
optimizer.Step()
l = loss.data[0]
}
assert.InDelta(t, float64(0), l, 0.1)
assert.InDelta(t, float64(0), l, 0.2)
}

func iris() (*Tensor, *Tensor, error) {
Expand Down Expand Up @@ -139,3 +146,111 @@ func TestIris(t *testing.T) {
}
assert.InDelta(t, float32(0), l, 0.1)
}

// mnist fetches the "mnist" dataset through dataset.DownloadAndUnzip and loads
// its two splits, "train.libfm" and "test.libfm", with openMNISTFile.
//
// The first result is the training split and the second the test split. In
// each tuple, A is the first tensor returned by openMNISTFile (the images)
// and B is the second (the labels). The first error encountered is returned
// as-is.
func mnist() (lo.Tuple2[*Tensor, *Tensor], lo.Tuple2[*Tensor, *Tensor], error) {
	var trainSet, testSet lo.Tuple2[*Tensor, *Tensor]

	// Download and unzip dataset
	dir, err := dataset.DownloadAndUnzip("mnist")
	if err != nil {
		return trainSet, testSet, err
	}

	// Load the training split, then the test split
	if trainSet.A, trainSet.B, err = openMNISTFile(filepath.Join(dir, "train.libfm")); err != nil {
		return trainSet, testSet, err
	}
	if testSet.A, testSet.B, err = openMNISTFile(filepath.Join(dir, "test.libfm")); err != nil {
		return trainSet, testSet, err
	}
	return trainSet, testSet, nil
}

// openMNISTFile reads one MNIST split stored as libfm-style text, where every
// line has the form
//
//	<label> <index>:<value> <index>:<value> ...
//
// It returns the images as a tensor of shape (samples, 784) followed by the
// labels as a tensor of shape (samples). Pixels that a line does not list stay
// zero.
//
// An error is returned when the file cannot be opened or read, when a label,
// index or value cannot be parsed, when a feature lacks the ':' separator, or
// when an index falls outside [0, 784).
func openMNISTFile(path string) (*Tensor, *Tensor, error) {
	const numPixels = 784

	// Open file
	f, err := os.Open(path)
	if err != nil {
		return nil, nil, err
	}
	defer f.Close()

	// Read data line by line
	var (
		images []float32
		labels []float32
	)
	scanner := bufio.NewScanner(f)
	for lineNo := 1; scanner.Scan(); lineNo++ {
		splits := strings.Split(scanner.Text(), " ")

		// Parse label
		label, err := util.ParseFloat[float32](splits[0])
		if err != nil {
			return nil, nil, fmt.Errorf("%s:%d: parsing label %q: %w", path, lineNo, splits[0], err)
		}
		labels = append(labels, label)

		// Parse image
		image := make([]float32, numPixels)
		for _, split := range splits[1:] {
			key, val, ok := strings.Cut(split, ":")
			if !ok {
				// Report malformed input instead of panicking on a missing value.
				return nil, nil, fmt.Errorf("%s:%d: feature %q is not of the form index:value", path, lineNo, split)
			}
			index, err := strconv.Atoi(key)
			if err != nil {
				return nil, nil, fmt.Errorf("%s:%d: parsing index %q: %w", path, lineNo, key, err)
			}
			if index < 0 || index >= numPixels {
				// An out-of-range index would otherwise panic on the write below.
				return nil, nil, fmt.Errorf("%s:%d: pixel index %d out of range [0, %d)", path, lineNo, index, numPixels)
			}
			value, err := util.ParseFloat[float32](val)
			if err != nil {
				return nil, nil, fmt.Errorf("%s:%d: parsing value %q: %w", path, lineNo, val, err)
			}
			image[index] = value
		}
		images = append(images, image...)
	}
	// Scan returns false on both EOF and failure; only Err tells them apart,
	// and ignoring it would silently yield a truncated dataset.
	if err := scanner.Err(); err != nil {
		return nil, nil, fmt.Errorf("reading %s: %w", path, err)
	}
	return NewTensor(images, len(labels), numPixels), NewTensor(labels, len(labels)), nil
}

// TestMNIST trains a 784-1000-10 network with ReLU on the MNIST training
// split using Adam and softmax cross-entropy, then checks two things: the mean
// per-batch training loss of the final epoch is below 0.4, and the fraction of
// correctly classified test samples is above 0.92.
func TestMNIST(t *testing.T) {
	train, test, err := mnist()
	if !assert.NoError(t, err) {
		// The tensors are nil when loading fails; stop here rather than
		// panicking on the dereferences below.
		return
	}

	model := NewSequential(
		NewLinear(784, 1000),
		NewReLU(),
		NewLinear(1000, 10),
	)
	optimizer := NewAdam(model.Parameters(), 0.001)

	const (
		epochs    = 3
		batchSize = 1000
	)
	var (
		sumLoss    float32
		numSamples = train.A.shape[0]
	)
	for i := 0; i < epochs; i++ {
		sumLoss = 0
		numBatches := 0
		bar := progressbar.Default(int64(numSamples), fmt.Sprintf("Epoch %v/%v", i+1, epochs))
		for j := 0; j < numSamples; j += batchSize {
			// Clamp the final batch so a sample count that is not a multiple
			// of batchSize does not slice past the end of the data.
			end := j + batchSize
			if end > numSamples {
				end = numSamples
			}
			xBatch := train.A.Slice(j, end)
			yBatch := train.B.Slice(j, end)

			yPred := model.Forward(xBatch)
			loss := SoftmaxCrossEntropy(yPred, yBatch)

			optimizer.ZeroGrad()
			loss.Backward()

			optimizer.Step()
			sumLoss += loss.data[0]
			numBatches++
			bar.Add(end - j)
		}
		// Average over the batches actually processed, not a truncated quotient.
		sumLoss /= float32(numBatches)
		bar.Finish()
	}
	assert.Less(t, sumLoss, float32(0.4))

	// Evaluate on the test split: count samples whose highest-scoring class
	// equals the ground-truth label.
	testPred := model.Forward(test.A)
	var accuracy float32
	for i, gt := range test.B.data {
		if testPred.Slice(i, i+1).argmax()[1] == int(gt) {
			accuracy += 1
		}
	}
	accuracy /= float32(len(test.B.data))
	assert.Greater(t, float64(accuracy), 0.92)
}
4 changes: 2 additions & 2 deletions common/nn/op.go
Original file line number Diff line number Diff line change
Expand Up @@ -710,8 +710,8 @@ func (r *relu) forward(inputs ...*Tensor) *Tensor {
}

func (r *relu) backward(dy *Tensor) []*Tensor {
dx := dy.clone()
dx.maximum(NewScalar(0))
x := r.inputs[0]
dx := x.clone().gt(NewScalar(0)).mul(dy)
return []*Tensor{dx}
}

Expand Down
3 changes: 0 additions & 3 deletions common/nn/op_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,6 @@ func TestDiv(t *testing.T) {
assert.InDeltaSlice(t, []float32{0.5, 2.0 / 3.0, 0.75, 4.0 / 5.0, 5.0 / 6.0, 6.0 / 7.0}, z.data, 1e-6)

// Test gradient
x = Rand(2, 3).RequireGrad()
y = Rand(2, 3).RequireGrad()
z = Div(x, y)
z.Backward()
dx := numericalDiff(func(x *Tensor) *Tensor { return Div(x, y) }, x)
allClose(t, x.grad, dx)
Expand Down
64 changes: 56 additions & 8 deletions common/nn/tensor.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,16 @@ package nn
import (
"container/heap"
"fmt"
"math"
"math/rand"
"strings"

"github.com/chewxy/math32"
mapset "github.com/deckarep/golang-set/v2"
"github.com/google/uuid"
"github.com/samber/lo"
"github.com/zhenghaoz/gorse/base/floats"
"golang.org/x/exp/slices"
"math"
"math/rand"
"strings"
)

type Tensor struct {
Expand Down Expand Up @@ -94,6 +95,21 @@ func Rand(shape ...int) *Tensor {
}
}

// Normal creates a tensor of the given shape whose elements are drawn
// independently from a normal distribution with the given mean and standard
// deviation. Samples come from math/rand's package-level source.
func Normal(mean, std float32, shape ...int) *Tensor {
	// The element count is the product of all dimensions.
	size := 1
	for i := 0; i < len(shape); i++ {
		size *= shape[i]
	}

	values := make([]float32, size)
	for i := 0; i < size; i++ {
		// Scale and shift a standard normal sample.
		values[i] = float32(rand.NormFloat64())*std + mean
	}
	return &Tensor{data: values, shape: shape}
}

// Ones creates a tensor filled with ones.
func Ones(shape ...int) *Tensor {
n := 1
Expand Down Expand Up @@ -590,6 +606,27 @@ func (t *Tensor) maximum(other *Tensor) {
}
}

// gt performs an element-wise "greater than" comparison in place: every
// element of t is overwritten with 1 if it is strictly greater than its
// counterpart in other, and with 0 otherwise. When other.IsScalar() reports
// true, every element is compared against other.data[0]; otherwise element i
// is compared against other.data[i]. The receiver is returned so calls can be
// chained.
func (t *Tensor) gt(other *Tensor) *Tensor {
	scalar := other.IsScalar()
	for i, v := range t.data {
		// Pick the right-hand operand: the single scalar value or the
		// element at the same position.
		j := i
		if scalar {
			j = 0
		}
		if v > other.data[j] {
			t.data[i] = 1
			continue
		}
		t.data[i] = 0
	}
	return t
}

func (t *Tensor) transpose() *Tensor {
if len(t.shape) < 2 {
panic("transpose requires at least 2-D tensor")
Expand Down Expand Up @@ -694,13 +731,24 @@ func (t *Tensor) sum(axis int, keepDim bool) *Tensor {
}
}

func (t *Tensor) hasNaN() bool {
for i := range t.data {
if math32.IsNaN(t.data[i]) {
return true
func (t *Tensor) argmax() []int {
if len(t.data) == 0 {
return nil
}
maxValue := t.data[0]
maxIndex := 0
for i := 1; i < len(t.data); i++ {
if t.data[i] > maxValue {
maxValue = t.data[i]
maxIndex = i
}
}
return false
indices := make([]int, len(t.shape))
for i := len(t.shape) - 1; i >= 0; i-- {
indices[i] = maxIndex % t.shape[i]
maxIndex /= t.shape[i]
}
return indices
}

func NormalInit(t *Tensor, mean, std float32) {
Expand Down
13 changes: 6 additions & 7 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ require (
github.com/araddon/dateparse v0.0.0-20210429162001-6b43995a97de
github.com/benhoyt/goawk v1.20.0
github.com/bits-and-blooms/bitset v1.2.1
github.com/chewxy/math32 v1.10.1
github.com/chewxy/math32 v1.11.1
github.com/coreos/go-oidc/v3 v3.11.0
github.com/deckarep/golang-set/v2 v2.3.1
github.com/emicklei/go-restful-openapi/v2 v2.9.0
Expand Down Expand Up @@ -40,7 +40,7 @@ require (
github.com/redis/go-redis/extra/redisotel/v9 v9.5.3
github.com/redis/go-redis/v9 v9.7.0
github.com/samber/lo v1.38.1
github.com/schollz/progressbar/v3 v3.9.0
github.com/schollz/progressbar/v3 v3.17.1
github.com/sclevine/yj v0.0.0-20210612025309-737bdf40a5d1
github.com/spf13/cobra v1.5.0
github.com/spf13/pflag v1.0.5
Expand Down Expand Up @@ -126,8 +126,7 @@ require (
github.com/leodido/go-urn v1.4.0 // indirect
github.com/magiconair/properties v1.8.6 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/mattn/go-isatty v0.0.16 // indirect
github.com/mattn/go-runewidth v0.0.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
github.com/minio/blake2b-simd v0.0.0-20160723061019-3f5f724cb5b1 // indirect
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect
Expand All @@ -145,7 +144,7 @@ require (
github.com/prometheus/procfs v0.8.0 // indirect
github.com/redis/go-redis/extra/rediscmd/v9 v9.5.3 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/rivo/uniseg v0.3.4 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/shopspring/decimal v1.3.1 // indirect
github.com/spf13/afero v1.9.2 // indirect
github.com/spf13/cast v1.5.0 // indirect
Expand All @@ -165,8 +164,8 @@ require (
golang.org/x/mod v0.17.0 // indirect
golang.org/x/net v0.30.0 // indirect
golang.org/x/sync v0.8.0 // indirect
golang.org/x/sys v0.26.0 // indirect
golang.org/x/term v0.25.0 // indirect
golang.org/x/sys v0.28.0 // indirect
golang.org/x/term v0.27.0 // indirect
golang.org/x/text v0.19.0 // indirect
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
golang.org/x/xerrors v0.0.0-20220517211312-f3a8303e98df // indirect
Expand Down
Loading
Loading