-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
23 lines (21 loc) · 925 Bytes
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import numpy as np
import tensorflow as tf
def gen_mnist_model():
    """Build a small LeNet-like convolutional classifier for 28x28x1 inputs.

    Layer stack, in creation order:
      * input of shape (28, 28, 1);
      * two stages of Conv2D (20 then 50 filters, 5x5 kernels) each followed
        by a 2x2 max-pool with stride 2 -- no ``activation`` argument is
        passed to the convolutions, so Keras' default (linear) applies;
      * Flatten, a 500-unit ReLU Dense layer, and a 10-unit softmax output.

    Returns:
        An uncompiled ``tf.keras.models.Sequential`` model.
    """
    layers = tf.keras.layers
    stack = [layers.InputLayer(input_shape=(28, 28, 1))]
    # Layers are created in the same order as a flat listing would create
    # them, so Keras' auto-generated layer names are unchanged.
    for n_filters in (20, 50):
        stack.append(layers.Conv2D(filters=n_filters, kernel_size=5))
        stack.append(layers.MaxPool2D(pool_size=2, strides=2))
    stack.extend([
        layers.Flatten(),
        layers.Dense(500, activation='relu'),
        layers.Dense(10, activation='softmax'),
    ])
    return tf.keras.models.Sequential(stack)
def gen_mnist_iterator(x, y, bs):
    """Create an endless, shuffled, batched iterator over image/label arrays.

    Preprocessing:
      * ``x`` is divided by 255 and cast to float32, then a trailing channel
        axis is appended. NOTE(review): this presumably expects uint8 pixel
        values of shape (N, 28, 28), matching ``gen_mnist_model``'s
        (28, 28, 1) input -- confirm against callers.
      * ``y`` is one-hot encoded with depth 10 via ``tf.one_hot``, so it must
        hold integer class indices.

    The dataset is shuffled with a buffer as large as the whole input, then
    repeated indefinitely, batched, and prefetched one batch ahead.

    Args:
        x: NumPy array of images (must support ``/``, ``.astype``, ``.shape``).
        y: Integer class labels, one per image.
        bs: Batch size.

    Returns:
        A one-shot iterator yielding ``(images, one_hot_labels)`` batches
        forever.
    """
    images = (x / 255.0).astype(np.float32)[..., tf.newaxis]
    labels = tf.one_hot(y, 10)
    ds = (tf.data.Dataset.from_tensor_slices((images, labels))
          .shuffle(images.shape[0])
          .repeat()
          .batch(bs)
          .prefetch(1))
    # TF 1.x exposes the iterator factory as a Dataset method; it was
    # removed in TF 2.x, where tf.compat.v1.data.make_one_shot_iterator is
    # the documented replacement. Prefer the method when it exists so TF1
    # behavior is unchanged.
    legacy_factory = getattr(ds, 'make_one_shot_iterator', None)
    if legacy_factory is not None:
        return legacy_factory()
    return tf.compat.v1.data.make_one_shot_iterator(ds)