Merge pull request apache#17 from dato-code/builtin_symbol_and_documentation

Builtin symbol and documentation
Jay Gu committed Apr 4, 2016
2 parents 4ad485c + 9ceaef9 commit c483b3d
Showing 18 changed files with 720 additions and 52 deletions.
1 change: 1 addition & 0 deletions python/mxnet/__init__.py
@@ -43,5 +43,6 @@

 from . import torch
 from . import torch as th
+from . import builtin_symbols

 __version__ = base.__version__
9 changes: 5 additions & 4 deletions python/mxnet/base.py
@@ -12,6 +12,7 @@
 from . import libinfo

 __all__ = ['MXNetError']
+__LOGGER__ = logging.getLogger(__name__)
 #----------------------------
 # library loading
 #----------------------------
@@ -38,13 +39,13 @@ def _load_lib():
     if os.path.exists(cuda_lib_path):
         try:
             lib = ctypes.cdll.LoadLibrary(cuda_lib_path)
-            logging.info("CUDA GPU support is activated.")
+            __LOGGER__.info("CUDA GPU support is activated.")
         except Exception as e:
-            logging.warn("Fail loading CUDA library. Error: %s" % e)
-            logging.info("Please try adding the CUDA installation path to LD_LIBRARY_PATH. Running CPU only mode.")
+            __LOGGER__.warn("Failed to load CUDA library. Error: %s" % e)
+            __LOGGER__.info("Please try adding the CUDA installation path to LD_LIBRARY_PATH. Running in CPU-only mode.")
             lib = ctypes.cdll.LoadLibrary(lib_path)
     else:
-        logging.info("CUDA support is not available on this platform. Running CPU only mode.")
+        __LOGGER__.info("CUDA support is not available on this platform. Running in CPU-only mode.")
         lib = ctypes.cdll.LoadLibrary(lib_path)
     # DMatrix functions
     lib.MXGetLastError.restype = ctypes.c_char_p
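Note on the logger change above: the old module-level `logging.info(...)` calls implicitly configured a root handler, so the messages printed by default; `__LOGGER__.info(...)` does not. A minimal sketch of how an application would surface the CUDA-detection messages after this change:

import logging

# Configure a root handler at INFO before importing mxnet, so the
# CUDA-detection messages emitted while the library loads are visible.
logging.basicConfig(level=logging.INFO)

import mxnet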
7 changes: 7 additions & 0 deletions python/mxnet/builtin_symbols/__init__.py
@@ -0,0 +1,7 @@
from . import symbol_alexnet
from . import symbol_googlenet
from . import symbol_vgg
from . import symbol_inception_v3
from . import symbol_inception_bn
from . import symbol_inception_bn_full
from . import symbol_inception_bn_28_small
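Together with the `from . import builtin_symbols` line added to `mxnet/__init__.py` above, this makes the architectures importable with no extra setup. A minimal usage sketch (the modules shown in this diff all follow the same `get_symbol(num_classes=...)` convention):

import mxnet as mx

# Each builtin_symbols module exposes a get_symbol factory that
# returns an mxnet.symbol.Symbol for the named architecture.
net = mx.builtin_symbols.symbol_alexnet.get_symbol(num_classes=1000)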
55 changes: 55 additions & 0 deletions python/mxnet/builtin_symbols/symbol_alexnet.py
@@ -0,0 +1,55 @@
from .. import symbol

def get_symbol(num_classes=1000):
    """
    Return the "AlexNet" architecture for image classification

    Parameters
    ----------
    num_classes : int, optional
        Number of classes in the output layer.

    References
    ----------
    - Krizhevsky, Alex, Ilya Sutskever, and Geoffrey E. Hinton. "Imagenet
      classification with deep convolutional neural networks." Advances in neural
      information processing systems. 2012.
    """
    input_data = symbol.Variable(name="data")
    # stage 1
    conv1 = symbol.Convolution(
        data=input_data, kernel=(11, 11), stride=(4, 4), num_filter=96)
    relu1 = symbol.Activation(data=conv1, act_type="relu")
    pool1 = symbol.Pooling(
        data=relu1, pool_type="max", kernel=(3, 3), stride=(2, 2))
    lrn1 = symbol.LRN(data=pool1, alpha=0.0001, beta=0.75, knorm=1, nsize=5)
    # stage 2
    conv2 = symbol.Convolution(
        data=lrn1, kernel=(5, 5), pad=(2, 2), num_filter=256)
    relu2 = symbol.Activation(data=conv2, act_type="relu")
    pool2 = symbol.Pooling(data=relu2, kernel=(3, 3), stride=(2, 2), pool_type="max")
    lrn2 = symbol.LRN(data=pool2, alpha=0.0001, beta=0.75, knorm=1, nsize=5)
    # stage 3
    conv3 = symbol.Convolution(
        data=lrn2, kernel=(3, 3), pad=(1, 1), num_filter=384)
    relu3 = symbol.Activation(data=conv3, act_type="relu")
    conv4 = symbol.Convolution(
        data=relu3, kernel=(3, 3), pad=(1, 1), num_filter=384)
    relu4 = symbol.Activation(data=conv4, act_type="relu")
    conv5 = symbol.Convolution(
        data=relu4, kernel=(3, 3), pad=(1, 1), num_filter=256)
    relu5 = symbol.Activation(data=conv5, act_type="relu")
    pool3 = symbol.Pooling(data=relu5, kernel=(3, 3), stride=(2, 2), pool_type="max")
    # stage 4
    flatten = symbol.Flatten(data=pool3)
    fc1 = symbol.FullyConnected(data=flatten, num_hidden=4096)
    relu6 = symbol.Activation(data=fc1, act_type="relu")
    dropout1 = symbol.Dropout(data=relu6, p=0.5)
    # stage 5
    fc2 = symbol.FullyConnected(data=dropout1, num_hidden=4096)
    relu7 = symbol.Activation(data=fc2, act_type="relu")
    dropout2 = symbol.Dropout(data=relu7, p=0.5)
    # stage 6
    fc3 = symbol.FullyConnected(data=dropout2, num_hidden=num_classes)
    softmax = symbol.SoftmaxOutput(data=fc3, name='softmax')
    return softmax
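As a quick sanity check, the symbol's shapes can be inferred end to end with the standard Symbol API; a sketch assuming a 224 x 224 input as in the original paper:

import mxnet as mx

net = mx.builtin_symbols.symbol_alexnet.get_symbol(num_classes=1000)
# Infer every intermediate shape from one (batch, channel, height, width) input.
arg_shapes, out_shapes, aux_shapes = net.infer_shape(data=(1, 3, 224, 224))
print(out_shapes)   # expected: [(1, 1000)]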
61 changes: 61 additions & 0 deletions python/mxnet/builtin_symbols/symbol_googlenet.py
@@ -0,0 +1,61 @@
from .. import symbol

def ConvFactory(data, num_filter, kernel, stride=(1, 1), pad=(0, 0), name=None, suffix=''):
    conv = symbol.Convolution(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad, name='conv_%s%s' % (name, suffix))
    act = symbol.Activation(data=conv, act_type='relu', name='relu_%s%s' % (name, suffix))
    return act

def InceptionFactory(data, num_1x1, num_3x3red, num_3x3, num_d5x5red, num_d5x5, pool, proj, name):
    # 1x1
    c1x1 = ConvFactory(data=data, num_filter=num_1x1, kernel=(1, 1), name=('%s_1x1' % name))
    # 3x3 reduce + 3x3
    c3x3r = ConvFactory(data=data, num_filter=num_3x3red, kernel=(1, 1), name=('%s_3x3' % name), suffix='_reduce')
    c3x3 = ConvFactory(data=c3x3r, num_filter=num_3x3, kernel=(3, 3), pad=(1, 1), name=('%s_3x3' % name))
    # 5x5 reduce + 5x5
    cd5x5r = ConvFactory(data=data, num_filter=num_d5x5red, kernel=(1, 1), name=('%s_5x5' % name), suffix='_reduce')
    cd5x5 = ConvFactory(data=cd5x5r, num_filter=num_d5x5, kernel=(5, 5), pad=(2, 2), name=('%s_5x5' % name))
    # pool + proj
    pooling = symbol.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
    cproj = ConvFactory(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_proj' % name))
    # concat
    concat = symbol.Concat(*[c1x1, c3x3, cd5x5, cproj], name='ch_concat_%s_chconcat' % name)
    return concat

def get_symbol(num_classes=1000):
    """
    Return the "GoogLeNet" architecture for image classification

    Parameters
    ----------
    num_classes : int, optional
        Number of classes in the output layer.

    References
    ----------
    - Szegedy, Christian, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, Dragomir
      Anguelov, Dumitru Erhan, Vincent Vanhoucke, and Andrew Rabinovich. "Going deeper
      with convolutions." arXiv preprint arXiv:1409.4842, 2014.
    """
    data = symbol.Variable("data")
    conv1 = ConvFactory(data, 64, kernel=(7, 7), stride=(2, 2), pad=(3, 3), name="conv1")
    pool1 = symbol.Pooling(conv1, kernel=(3, 3), stride=(2, 2), pool_type="max")
    conv2 = ConvFactory(pool1, 64, kernel=(1, 1), stride=(1, 1), name="conv2")
    conv3 = ConvFactory(conv2, 192, kernel=(3, 3), stride=(1, 1), pad=(1, 1), name="conv3")
    pool3 = symbol.Pooling(conv3, kernel=(3, 3), stride=(2, 2), pool_type="max")

    in3a = InceptionFactory(pool3, 64, 96, 128, 16, 32, "max", 32, name="in3a")
    in3b = InceptionFactory(in3a, 128, 128, 192, 32, 96, "max", 64, name="in3b")
    pool4 = symbol.Pooling(in3b, kernel=(3, 3), stride=(2, 2), pool_type="max")
    in4a = InceptionFactory(pool4, 192, 96, 208, 16, 48, "max", 64, name="in4a")
    in4b = InceptionFactory(in4a, 160, 112, 224, 24, 64, "max", 64, name="in4b")
    in4c = InceptionFactory(in4b, 128, 128, 256, 24, 64, "max", 64, name="in4c")
    in4d = InceptionFactory(in4c, 112, 144, 288, 32, 64, "max", 64, name="in4d")
    in4e = InceptionFactory(in4d, 256, 160, 320, 32, 128, "max", 128, name="in4e")
    pool5 = symbol.Pooling(in4e, kernel=(3, 3), stride=(2, 2), pool_type="max")
    in5a = InceptionFactory(pool5, 256, 160, 320, 32, 128, "max", 128, name="in5a")
    in5b = InceptionFactory(in5a, 384, 192, 384, 48, 128, "max", 128, name="in5b")
    pool6 = symbol.Pooling(in5b, kernel=(7, 7), stride=(1, 1), pool_type="avg")
    flatten = symbol.Flatten(data=pool6)
    fc1 = symbol.FullyConnected(data=flatten, num_hidden=num_classes)
    softmax = symbol.SoftmaxOutput(data=fc1, name='softmax')
    return softmax
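Because ConvFactory derives every layer name from its `name` and `suffix` arguments, the learned parameters get predictable names. A short sketch of how one might inspect them (standard Symbol API, nothing specific to this PR):

import mxnet as mx

net = mx.builtin_symbols.symbol_googlenet.get_symbol(num_classes=1000)
# Weights follow the 'conv_<name><suffix>_weight' pattern set in ConvFactory.
print([arg for arg in net.list_arguments() if arg.startswith('conv_conv1')])
# e.g. ['conv_conv1_weight', 'conv_conv1_bias']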
86 changes: 86 additions & 0 deletions python/mxnet/builtin_symbols/symbol_inception_bn.py
@@ -0,0 +1,86 @@
from .. import symbol

def ConvFactory(data, num_filter, kernel, stride=(1, 1), pad=(0, 0), name=None, suffix=''):
    conv = symbol.Convolution(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad, name='conv_%s%s' % (name, suffix))
    bn = symbol.BatchNorm(data=conv, name='bn_%s%s' % (name, suffix))
    act = symbol.Activation(data=bn, act_type='relu', name='relu_%s%s' % (name, suffix))
    return act

def InceptionFactoryA(data, num_1x1, num_3x3red, num_3x3, num_d3x3red, num_d3x3, pool, proj, name):
    # 1x1
    c1x1 = ConvFactory(data=data, num_filter=num_1x1, kernel=(1, 1), name=('%s_1x1' % name))
    # 3x3 reduce + 3x3
    c3x3r = ConvFactory(data=data, num_filter=num_3x3red, kernel=(1, 1), name=('%s_3x3' % name), suffix='_reduce')
    c3x3 = ConvFactory(data=c3x3r, num_filter=num_3x3, kernel=(3, 3), pad=(1, 1), name=('%s_3x3' % name))
    # double 3x3 reduce + double 3x3
    cd3x3r = ConvFactory(data=data, num_filter=num_d3x3red, kernel=(1, 1), name=('%s_double_3x3' % name), suffix='_reduce')
    cd3x3 = ConvFactory(data=cd3x3r, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), name=('%s_double_3x3_0' % name))
    cd3x3 = ConvFactory(data=cd3x3, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), name=('%s_double_3x3_1' % name))
    # pool + proj
    pooling = symbol.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
    cproj = ConvFactory(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_proj' % name))
    # concat
    concat = symbol.Concat(*[c1x1, c3x3, cd3x3, cproj], name='ch_concat_%s_chconcat' % name)
    return concat

def InceptionFactoryB(data, num_3x3red, num_3x3, num_d3x3red, num_d3x3, name):
    # 3x3 reduce + 3x3
    c3x3r = ConvFactory(data=data, num_filter=num_3x3red, kernel=(1, 1), name=('%s_3x3' % name), suffix='_reduce')
    c3x3 = ConvFactory(data=c3x3r, num_filter=num_3x3, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name=('%s_3x3' % name))
    # double 3x3 reduce + double 3x3
    cd3x3r = ConvFactory(data=data, num_filter=num_d3x3red, kernel=(1, 1), name=('%s_double_3x3' % name), suffix='_reduce')
    cd3x3 = ConvFactory(data=cd3x3r, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name=('%s_double_3x3_0' % name))
    cd3x3 = ConvFactory(data=cd3x3, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name=('%s_double_3x3_1' % name))
    # pool + proj
    pooling = symbol.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pool_type="max", name=('max_pool_%s_pool' % name))
    # concat
    concat = symbol.Concat(*[c3x3, cd3x3, pooling], name='ch_concat_%s_chconcat' % name)
    return concat

def get_symbol(num_classes=1000):
    """
    Return the "BN-Inception" architecture for image classification

    The network is suitable for images of size around 224 x 224

    Parameters
    ----------
    num_classes : int, optional
        Number of classes in the output layer.

    References
    ----------
    - Sergey Ioffe and Christian Szegedy. Batch normalization: Accelerating deep
      network training by reducing internal covariate shift. arXiv preprint
      arXiv:1502.03167, 2015.
    """
    # data
    data = symbol.Variable(name="data")
    # stage 1
    conv1 = ConvFactory(data=data, num_filter=64, kernel=(7, 7), stride=(2, 2), pad=(3, 3), name='conv1')
    pool1 = symbol.Pooling(data=conv1, kernel=(3, 3), stride=(2, 2), name='pool1', pool_type='max')
    # stage 2
    conv2red = ConvFactory(data=pool1, num_filter=64, kernel=(1, 1), stride=(1, 1), name='conv2red')
    conv2 = ConvFactory(data=conv2red, num_filter=192, kernel=(3, 3), stride=(1, 1), pad=(1, 1), name='conv2')
    pool2 = symbol.Pooling(data=conv2, kernel=(3, 3), stride=(2, 2), name='pool2', pool_type='max')
    # stage 3
    in3a = InceptionFactoryA(pool2, 64, 64, 64, 64, 96, "avg", 32, '3a')
    in3b = InceptionFactoryA(in3a, 64, 64, 96, 64, 96, "avg", 64, '3b')
    in3c = InceptionFactoryB(in3b, 128, 160, 64, 96, '3c')
    # stage 4
    in4a = InceptionFactoryA(in3c, 224, 64, 96, 96, 128, "avg", 128, '4a')
    in4b = InceptionFactoryA(in4a, 192, 96, 128, 96, 128, "avg", 128, '4b')
    in4c = InceptionFactoryA(in4b, 160, 128, 160, 128, 160, "avg", 128, '4c')
    in4d = InceptionFactoryA(in4c, 96, 128, 192, 160, 192, "avg", 128, '4d')
    in4e = InceptionFactoryB(in4d, 128, 192, 192, 256, '4e')
    # stage 5
    in5a = InceptionFactoryA(in4e, 352, 192, 320, 160, 224, "avg", 128, '5a')
    in5b = InceptionFactoryA(in5a, 352, 192, 320, 192, 224, "max", 128, '5b')
    # global avg pooling
    avg = symbol.Pooling(data=in5b, kernel=(7, 7), stride=(1, 1), name="global_pool", pool_type='avg')
    # linear classifier
    flatten = symbol.Flatten(data=avg, name='flatten')
    fc1 = symbol.FullyConnected(data=flatten, num_hidden=num_classes, name='fc1')
    softmax = symbol.SoftmaxOutput(data=fc1, name='softmax')
    return softmax
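As with the other architectures, shape inference gives a quick structural check; a sketch assuming the 224 x 224 input size the docstring recommends (intermediate shapes depend on MXNet's pooling convention, so only the final output shape is asserted here):

import mxnet as mx

net = mx.builtin_symbols.symbol_inception_bn.get_symbol(num_classes=1000)
arg_shapes, out_shapes, aux_shapes = net.infer_shape(data=(1, 3, 224, 224))
print(out_shapes)   # expected: [(1, 1000)]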
63 changes: 63 additions & 0 deletions python/mxnet/builtin_symbols/symbol_inception_bn_28_small.py
@@ -0,0 +1,63 @@
from .. import symbol

# Basic Conv + BN + ReLU factory
def ConvFactory(data, num_filter, kernel, stride=(1, 1), pad=(0, 0), act_type="relu"):
    conv = symbol.Convolution(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad)
    bn = symbol.BatchNorm(data=conv)
    act = symbol.Activation(data=bn, act_type=act_type)
    return act

# A simple downsampling factory
def DownsampleFactory(data, ch_3x3):
    # conv 3x3
    conv = ConvFactory(data=data, kernel=(3, 3), stride=(2, 2), num_filter=ch_3x3, pad=(1, 1))
    # pool
    pool = symbol.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pool_type='max')
    # concat
    concat = symbol.Concat(*[conv, pool])
    return concat

# A simple module
def SimpleFactory(data, ch_1x1, ch_3x3):
    # 1x1
    conv1x1 = ConvFactory(data=data, kernel=(1, 1), pad=(0, 0), num_filter=ch_1x1)
    # 3x3
    conv3x3 = ConvFactory(data=data, kernel=(3, 3), pad=(1, 1), num_filter=ch_3x3)
    # concat
    concat = symbol.Concat(*[conv1x1, conv3x3])
    return concat

def get_symbol(num_classes=10):
    """
    Return a simplified version of the "BN-Inception" architecture for image classification

    The network is suitable for images of size around 28 x 28

    Parameters
    ----------
    num_classes : int, optional
        Number of classes in the output layer.

    References
    ----------
    - Sergey Ioffe and Christian Szegedy. Batch normalization: Accelerating deep
      network training by reducing internal covariate shift. arXiv preprint
      arXiv:1502.03167, 2015.
    """
    data = symbol.Variable(name="data")
    conv1 = ConvFactory(data=data, kernel=(3, 3), pad=(1, 1), num_filter=96, act_type="relu")
    in3a = SimpleFactory(conv1, 32, 32)
    in3b = SimpleFactory(in3a, 32, 48)
    in3c = DownsampleFactory(in3b, 80)
    in4a = SimpleFactory(in3c, 112, 48)
    in4b = SimpleFactory(in4a, 96, 64)
    in4c = SimpleFactory(in4b, 80, 80)
    in4d = SimpleFactory(in4c, 48, 96)
    in4e = DownsampleFactory(in4d, 96)
    in5a = SimpleFactory(in4e, 176, 160)
    in5b = SimpleFactory(in5a, 176, 160)
    pool = symbol.Pooling(data=in5b, pool_type="avg", kernel=(7, 7), name="global_pool")
    flatten = symbol.Flatten(data=pool, name="flatten1")
    fc = symbol.FullyConnected(data=flatten, num_hidden=num_classes, name="fc1")
    softmax = symbol.SoftmaxOutput(data=fc, name="softmax")
    return softmax
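Since this variant targets 28 x 28 inputs such as MNIST, it plugs directly into the high-level trainer of this era of MXNet. A hedged sketch (`mx.model.FeedForward` was the training front end at the time; `train_iter` is a hypothetical data iterator yielding (batch, channel, 28, 28) images):

import mxnet as mx

net = mx.builtin_symbols.symbol_inception_bn_28_small.get_symbol(num_classes=10)
model = mx.model.FeedForward(symbol=net, num_epoch=1, learning_rate=0.1)
# model.fit(X=train_iter)   # train_iter: hypothetical iterator of 28 x 28 images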
86 changes: 86 additions & 0 deletions python/mxnet/builtin_symbols/symbol_inception_bn_full.py
@@ -0,0 +1,86 @@
from .. import symbol

def ConvFactory(data, num_filter, kernel, stride=(1, 1), pad=(0, 0), name=None, suffix=''):
    conv = symbol.Convolution(data=data, workspace=512, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad, name='conv_%s%s' % (name, suffix))
    bn = symbol.BatchNorm(data=conv, name='bn_%s%s' % (name, suffix))
    act = symbol.Activation(data=bn, act_type='relu', name='relu_%s%s' % (name, suffix))
    return act

def InceptionFactoryA(data, num_1x1, num_3x3red, num_3x3, num_d3x3red, num_d3x3, pool, proj, name):
    # 1x1
    c1x1 = ConvFactory(data=data, num_filter=num_1x1, kernel=(1, 1), name=('%s_1x1' % name))
    # 3x3 reduce + 3x3
    c3x3r = ConvFactory(data=data, num_filter=num_3x3red, kernel=(1, 1), name=('%s_3x3' % name), suffix='_reduce')
    c3x3 = ConvFactory(data=c3x3r, num_filter=num_3x3, kernel=(3, 3), pad=(1, 1), name=('%s_3x3' % name))
    # double 3x3 reduce + double 3x3
    cd3x3r = ConvFactory(data=data, num_filter=num_d3x3red, kernel=(1, 1), name=('%s_double_3x3' % name), suffix='_reduce')
    cd3x3 = ConvFactory(data=cd3x3r, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), name=('%s_double_3x3_0' % name))
    cd3x3 = ConvFactory(data=cd3x3, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), name=('%s_double_3x3_1' % name))
    # pool + proj
    pooling = symbol.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
    cproj = ConvFactory(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_proj' % name))
    # concat
    concat = symbol.Concat(*[c1x1, c3x3, cd3x3, cproj], name='ch_concat_%s_chconcat' % name)
    return concat

def InceptionFactoryB(data, num_3x3red, num_3x3, num_d3x3red, num_d3x3, name):
    # 3x3 reduce + 3x3
    c3x3r = ConvFactory(data=data, num_filter=num_3x3red, kernel=(1, 1), name=('%s_3x3' % name), suffix='_reduce')
    c3x3 = ConvFactory(data=c3x3r, num_filter=num_3x3, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name=('%s_3x3' % name))
    # double 3x3 reduce + double 3x3
    cd3x3r = ConvFactory(data=data, num_filter=num_d3x3red, kernel=(1, 1), name=('%s_double_3x3' % name), suffix='_reduce')
    cd3x3 = ConvFactory(data=cd3x3r, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name=('%s_double_3x3_0' % name))
    cd3x3 = ConvFactory(data=cd3x3, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name=('%s_double_3x3_1' % name))
    # pool + proj
    pooling = symbol.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pool_type="max", name=('max_pool_%s_pool' % name))
    # concat
    concat = symbol.Concat(*[c3x3, cd3x3, pooling], name='ch_concat_%s_chconcat' % name)
    return concat

def get_symbol(num_classes=21841):
    """
    Return a variant of the "BN-Inception" architecture for image classification

    The network is suitable for the full ImageNet dataset with 21841 classes

    Parameters
    ----------
    num_classes : int, optional
        Number of classes in the output layer.

    References
    ----------
    - Sergey Ioffe and Christian Szegedy. Batch normalization: Accelerating deep
      network training by reducing internal covariate shift. arXiv preprint
      arXiv:1502.03167, 2015.
    """
    # data
    data = symbol.Variable(name="data")
    # stage 1
    conv1 = ConvFactory(data=data, num_filter=96, kernel=(7, 7), stride=(2, 2), pad=(3, 3), name='conv1')
    pool1 = symbol.Pooling(data=conv1, kernel=(3, 3), stride=(2, 2), name='pool1', pool_type='max')
    # stage 2
    conv2red = ConvFactory(data=pool1, num_filter=128, kernel=(1, 1), stride=(1, 1), name='conv2red')
    conv2 = ConvFactory(data=conv2red, num_filter=288, kernel=(3, 3), stride=(1, 1), pad=(1, 1), name='conv2')
    pool2 = symbol.Pooling(data=conv2, kernel=(3, 3), stride=(2, 2), name='pool2', pool_type='max')
    # stage 3
    in3a = InceptionFactoryA(pool2, 96, 96, 96, 96, 144, "avg", 48, '3a')
    in3b = InceptionFactoryA(in3a, 96, 96, 144, 96, 144, "avg", 96, '3b')
    in3c = InceptionFactoryB(in3b, 192, 240, 96, 144, '3c')
    # stage 4
    in4a = InceptionFactoryA(in3c, 224, 64, 96, 96, 128, "avg", 128, '4a')
    in4b = InceptionFactoryA(in4a, 192, 96, 128, 96, 128, "avg", 128, '4b')
    in4c = InceptionFactoryA(in4b, 160, 128, 160, 128, 160, "avg", 128, '4c')
    in4d = InceptionFactoryA(in4c, 96, 128, 192, 160, 96, "avg", 128, '4d')
    in4e = InceptionFactoryB(in4d, 128, 192, 192, 256, '4e')
    # stage 5
    in5a = InceptionFactoryA(in4e, 352, 192, 320, 160, 224, "avg", 128, '5a')
    in5b = InceptionFactoryA(in5a, 352, 192, 320, 192, 224, "max", 128, '5b')
    # global avg pooling
    avg = symbol.Pooling(data=in5b, kernel=(7, 7), stride=(1, 1), name="global_pool", pool_type='avg')
    # linear classifier
    flatten = symbol.Flatten(data=avg, name='flatten')
    fc1 = symbol.FullyConnected(data=flatten, num_hidden=num_classes, name='fc1')
    softmax = symbol.SoftmaxOutput(data=fc1, name='softmax')
    return softmax
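With 21841 output classes, the final classifier dominates the parameter count. A sketch of how one might confirm this via shape inference (the (21841, 1024) weight shape is arithmetic from the in5b concat width above, not something stated in the PR):

import mxnet as mx

net = mx.builtin_symbols.symbol_inception_bn_full.get_symbol()
arg_shapes, out_shapes, aux_shapes = net.infer_shape(data=(1, 3, 224, 224))
shapes = dict(zip(net.list_arguments(), arg_shapes))
# fc1 weight is (num_classes, channels after global_pool), i.e. roughly (21841, 1024).
print(shapes['fc1_weight'])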