-
Notifications
You must be signed in to change notification settings - Fork 1
/
mnist_train.kojo
149 lines (124 loc) · 4.14 KB
/
mnist_train.kojo
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
// Kojo preprocessor includes: bring in the nn (DJL wrappers: ndScoped,
// ndMaker, gradientCollector, ...) and plot (LiveChart) helper scripts.
// #include /nn.kojo
// #include /plot.kojo
import ai.djl.basicdataset.cv.classification.Mnist
// Seed both Kojo's RNG and the DJL engine so runs are reproducible.
val seed = 40
initRandomGenerator(seed)
Engine.getInstance.setRandomSeed(seed)
// Reset the Kojo drawing canvas and clear the output pane.
cleari()
clearOutput()
// Build, train, evaluate, and persist the model inside a managed
// NDArray scope so native resources are released when the scope ends.
ndScoped { scope =>
  val mnistModel = scope(new MnistModel())
  timeit("Training") {
    mnistModel.train()
  }
  mnistModel.test()
  mnistModel.save()
}
// Fetches (downloading if needed) the MNIST split selected by `usage`,
// batched at 64 samples with shuffling, allocated on the given manager.
def dataset(usage: Dataset.Usage, mgr: NDManager) = {
  val mnistDataset = Mnist
    .builder()
    .optUsage(usage)
    .setSampling(64, true) // batch size 64, shuffled
    .optManager(mgr)
    .build()
  // prepare() downloads/loads the data, reporting progress on the console.
  mnistDataset.prepare(new ProgressBar())
  mnistDataset
}
// A three-layer MLP (784 -> 38 -> 12 -> 10) for MNIST digit
// classification, trained with plain mini-batch SGD on manually
// managed NDArrays. AutoCloseable so ndScoped can release the params.
class MnistModel extends AutoCloseable {

  /** Step-wise learning-rate schedule keyed on the 1-based epoch number. */
  def learningRate(e: Int) = e match {
    case n if n <= 10 => 0.1f
    case n if n <= 15 => 0.03f
    case _            => 0.01f
  }

  val nm = ndMaker
  val trainingSet = dataset(Dataset.Usage.TRAIN, nm)
  val validateSet = dataset(Dataset.Usage.TEST, nm)

  // Hidden-layer widths.
  val hidden1 = 38
  val hidden2 = 12

  // Weights start from a small random normal; biases start at zero.
  val w1 = nm.randomNormal(0, 0.1f, Shape(784, hidden1), DataType.FLOAT32)
  val b1 = nm.zeros(Shape(hidden1))
  val w2 = nm.randomNormal(0, 0.1f, Shape(hidden1, hidden2), DataType.FLOAT32)
  val b2 = nm.zeros(Shape(hidden2))
  val w3 = nm.randomNormal(0, 0.1f, Shape(hidden2, 10), DataType.FLOAT32)
  val b3 = nm.zeros(Shape(10))
  val params = new NDList(w1, b1, w2, b2, w3, b3).asScala
  val softmax = new SoftmaxCrossEntropyLoss()

  // Enable gradient tracking for every trainable parameter.
  params.foreach { p =>
    p.setRequiresGradient(true)
  }

  /** Forward pass: two ReLU hidden layers, returning raw logits
    * (SoftmaxCrossEntropyLoss applies softmax internally).
    */
  def modelFunction(x: NDArray): NDArray = {
    val l1 = x.matMul(w1).add(b1)
    val l1a = Activation.relu(l1)
    val l2 = l1a.matMul(w2).add(b2)
    val l2a = Activation.relu(l2)
    l2a.matMul(w3).add(b3)
  }

  val numEpochs = 20

  /** Trains with mini-batch SGD, updating a live loss chart per epoch.
    * Note: the loss printed/plotted for an epoch is that of the LAST
    * batch in the epoch, not an epoch average.
    */
  def train(): Unit = {
    println("Training Started...")
    val lossChart = new LiveChart(
      "Loss Plot", "epoch", "loss", 0, numEpochs, 0, 1
    )
    for (epoch <- 1 to numEpochs) {
      var eloss = 0f
      trainingSet.getData(nm).asScala.foreach { batch0 =>
        ndScoped { use =>
          val batch = use(batch0)
          // try/finally so the native gradient collector is released
          // even if the forward/backward pass throws (the original
          // closed it manually and would leak on failure).
          val gc = gradientCollector
          try {
            val x = batch.getData.head.reshape(Shape(-1, 784))
            val y = batch.getLabels.head
            val yPred = use(modelFunction(x))
            val loss = use(softmax.evaluate(new NDList(y), new NDList(yPred)))
            eloss = loss.getFloat()
            gc.backward(loss)
          } finally {
            gc.close()
          }
          // SGD step; gradients must be zeroed before the next batch.
          params.foreach { p =>
            p.subi(p.getGradient.mul(learningRate(epoch)))
            p.zeroGradients()
          }
        }
      }
      println(s"[$epoch] Loss -- $eloss")
      lossChart.update(epoch, eloss)
    }
    println("Training Done")
  }

  /** Computes and prints classification accuracy on the test split. */
  def test(): Unit = {
    println("Determining accuracy on the test set")
    var total = 0L
    var totalGood = 0L
    validateSet.getData(nm).asScala.foreach { batch0 =>
      ndScoped { use =>
        val batch = use(batch0)
        val x = batch.getData.head.reshape(Shape(-1, 784))
        val y = batch.getLabels.head
        // argMax over the class axis yields the predicted digit per row.
        val yPred = use(modelFunction(x).softmax(1).argMax(1))
        val matches = use(y.toType(DataType.INT64, false).eq(yPred))
        total += matches.getShape.get(0)
        totalGood += matches.countNonzero.getLong()
      }
    }
    val acc = 1f * totalGood / total
    println(acc)
  }

  /** Serializes all parameters (after a leading 'P' tag char) to
    * mnist.djl.model under the Kojo base directory.
    */
  def save(): Unit = {
    import java.io._
    val modelFile = s"${kojoCtx.baseDir}/mnist.djl.model"
    println(s"Saving model in file - $modelFile")
    managed { use =>
      val dos = use(new DataOutputStream(
        new BufferedOutputStream(new FileOutputStream(modelFile))
      ))
      dos.writeChar('P')
      params.foreach { p =>
        dos.write(p.encode())
      }
    }
  }

  /** Releases every parameter NDArray and the backing NDManager. */
  def close(): Unit = {
    println("Closing remaining ndarrays...")
    params.foreach(_.close())
    nm.close()
    println("Done")
  }
}