From 964d42b6f3dbb393d4e5718e2daf349cb8e91d51 Mon Sep 17 00:00:00 2001
From: haijieg
Date: Wed, 30 Mar 2016 15:18:10 -0700
Subject: [PATCH 1/8] Change default model logger to print progress

---
 python/mxnet/model.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/python/mxnet/model.py b/python/mxnet/model.py
index 7458fd29e83b..ee3d6334137e 100644
--- a/python/mxnet/model.py
+++ b/python/mxnet/model.py
@@ -175,7 +175,8 @@ def _train_multi_device(symbol, ctx, arg_names, param_names, aux_names,
     - This function will inplace update the NDArrays in arg_params and aux_states.
     """
     if logger is None:
-        logger = logging
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.INFO)
     executor_manager = DataParallelExecutorManager(symbol=symbol,
                                                    ctx=ctx,
                                                    train_data=train_data,

From ab2494dc5f85001fb4ce9898ecfca8dbde5b110d Mon Sep 17 00:00:00 2001
From: haijieg
Date: Wed, 30 Mar 2016 15:18:31 -0700
Subject: [PATCH 2/8] Add data_name, label_name to sframeiter

---
 python/mxnet/io.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/python/mxnet/io.py b/python/mxnet/io.py
index 9bd8098a67a1..faf2ce681505 100644
--- a/python/mxnet/io.py
+++ b/python/mxnet/io.py
@@ -368,7 +368,7 @@ def getpad(self):
 
 
 class SFrameIter(DataIter):
-    def __init__(self, sframe, data_field, label_field=None, batch_size=1):
+    def __init__(self, sframe, data_field, label_field=None, batch_size=1, data_name='data', label_name='softmax_label'):
         """
         Iterator over an SFrame
 
@@ -411,8 +411,8 @@ def __init__(self, sframe, data_field, label_field=None, batch_size=1):
         self.field_length = inferred_shape["field_length"]
         self.data_ndarray = array(np.zeros(self.data_shape))
         self.label_ndarray = array(np.zeros(self.label_shape))
-        self.data = _init_data(self.data_ndarray, allow_empty=False, default_name="data")
-        self.label = _init_data(self.label_ndarray, allow_empty=True, default_name="softmax_label")
+        self.data = _init_data(self.data_ndarray, allow_empty=False, default_name=data_name)
+        self.label = _init_data(self.label_ndarray, allow_empty=True, default_name=label_name)
         # size
         self.batch_size = batch_size
         self.data_size = len(sframe)
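The point of the two new keyword arguments is that the iterator's data and label names can now be made to match whatever input and loss names a custom symbol declares, instead of being hard-coded to "data" and "softmax_label". A minimal sketch of the intended use (assuming a GraphLab SFrame `sf` with columns "features" and "label"; the symbol names here are illustrative):

    import mxnet as mx

    # For a network whose input variable is named 'x' and whose loss output is
    # named 'lro' (so the label argument it exposes is 'lro_label'):
    train_iter = mx.io.SFrameIter(sf, data_field='features', label_field='label',
                                  batch_size=32,
                                  data_name='x', label_name='lro_label')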
Error: %s" % e) + __LOGGER__.info("Please try adding the CUDA installation path to LD_LIBRARY_PATH. Running CPU only mode.") lib = ctypes.cdll.LoadLibrary(lib_path) else: - logging.info("CUDA support is currently not available on this platform. Running CPU only mode.") + __LOGGER__.info("CUDA support is currently not available on this platform. Running CPU only mode.") lib = ctypes.cdll.LoadLibrary(lib_path) # DMatrix functions lib.MXGetLastError.restype = ctypes.c_char_p diff --git a/python/mxnet/callback.py b/python/mxnet/callback.py index c6f466b22269..18d1a7d4355f 100644 --- a/python/mxnet/callback.py +++ b/python/mxnet/callback.py @@ -6,7 +6,11 @@ import math import logging import time -from .model import save_checkpoint +from .model import save_checkpoint as _save_checkpoint + +__LOGGER__ = logging.getLogger(__name__) +__LOGGER__.setLevel(logging.INFO) +__LOGGER__.info('callback logger activated') def do_checkpoint(prefix): """Callback to checkpoint the model to prefix every epoch. @@ -19,11 +23,11 @@ def do_checkpoint(prefix): Returns ------- callback : function - The callback function that can be passed as iter_end_callback to fit. + The callback function that can be passed as epoch_end_callback to fit. """ def _callback(iter_no, sym, arg, aux): """The checkpoint function.""" - save_checkpoint(prefix, iter_no + 1, sym, arg, aux) + _save_checkpoint(prefix, iter_no + 1, sym, arg, aux) return _callback @@ -38,13 +42,13 @@ def log_train_metric(period): Returns ------- callback : function - The callback function that can be passed as iter_epoch_callback to fit. + The callback function that can be passed as epoch_end_callback to fit. """ def _callback(param): """The checkpoint function.""" if param.nbatch % period == 0: name, value = param.eval_metric.get() - logging.info('Iter[%d] Batch[%d] Train-%s=%f', + __LOGGER__.info('Iter[%d] Batch[%d] Train-%s=%f', param.epoch, param.nbatch, name, value) return _callback @@ -54,9 +58,10 @@ class Speedometer(object): Parameters ---------- - batch_size: int + batch_size : int batch_size of data - frequent: int + + frequent : int calcutaion frequent """ def __init__(self, batch_size, frequent=50): @@ -78,10 +83,10 @@ def __call__(self, param): speed = self.frequent * self.batch_size / (time.time() - self.tic) if param.eval_metric is not None: name, value = param.eval_metric.get() - logging.info("Epoch[%d] Batch [%d]\tSpeed: %.2f samples/sec\tTrain-%s=%f", + __LOGGER__.info("Epoch[%d] Batch [%d]\tSpeed: %.2f samples/sec\tTrain-%s=%f", param.epoch, count, speed, name, value) else: - logging.info("Iter[%d] Batch [%d]\tSpeed: %.2f samples/sec", + __LOGGER__.info("Iter[%d] Batch [%d]\tSpeed: %.2f samples/sec", param.epoch, count, speed) self.tic = time.time() else: diff --git a/python/mxnet/initializer.py b/python/mxnet/initializer.py index 987c0fe104b3..ae00bfbdd55d 100644 --- a/python/mxnet/initializer.py +++ b/python/mxnet/initializer.py @@ -9,6 +9,8 @@ import logging import re +__LOGGER__ = logging.getLogger(__name__) + class Initializer(object): """Base class for Initializer.""" @@ -110,14 +112,14 @@ def __call__(self, name, arr): self.param[name].shape) arr[:] = self.param[name] if self.verbose: - logging.info('Initialized %s by loading', name) + __LOGGER__.info('Initialized %s by loading', name) else: assert self.default_init is not None, \ "Cannot Initialize %s. Not found in loaded param " + \ "and no default Initializer is provided." 
diff --git a/python/mxnet/initializer.py b/python/mxnet/initializer.py
index 987c0fe104b3..ae00bfbdd55d 100644
--- a/python/mxnet/initializer.py
+++ b/python/mxnet/initializer.py
@@ -9,6 +9,8 @@
 import logging
 import re
 
+__LOGGER__ = logging.getLogger(__name__)
+
 class Initializer(object):
     """Base class for Initializer."""
 
@@ -110,14 +112,14 @@ def __call__(self, name, arr):
                 self.param[name].shape)
             arr[:] = self.param[name]
             if self.verbose:
-                logging.info('Initialized %s by loading', name)
+                __LOGGER__.info('Initialized %s by loading', name)
         else:
             assert self.default_init is not None, \
                 "Cannot Initialize %s. Not found in loaded param " + \
                 "and no default Initializer is provided."
             self.default_init(name, arr)
             if self.verbose:
-                logging.info('Initialized %s by default', name)
+                __LOGGER__.info('Initialized %s by default', name)
 
 class Mixed(object):
     """Initialize with mixed Initializer
diff --git a/python/mxnet/lr_scheduler.py b/python/mxnet/lr_scheduler.py
index e40e146a0af8..b4a75a64836c 100644
--- a/python/mxnet/lr_scheduler.py
+++ b/python/mxnet/lr_scheduler.py
@@ -4,6 +4,8 @@
 """
 import logging
 
+__LOGGER__ = logging.getLogger(__name__)
+
 class LRScheduler(object):
     """Base class of a learning rate scheduler"""
     def __init__(self):
@@ -71,6 +73,6 @@ def __call__(self, num_update):
         if num_update > self.count + self.step:
             self.count += self.step
             self.base_lr *= self.factor
-            logging.info("Update[%d]: Change learning rate to %0.5e",
+            __LOGGER__.info("Update[%d]: Change learning rate to %0.5e",
                          num_update, self.base_lr)
         return self.base_lr
diff --git a/python/mxnet/misc.py b/python/mxnet/misc.py
index 2d3ffc6e5abd..dd5022b94d46 100644
--- a/python/mxnet/misc.py
+++ b/python/mxnet/misc.py
@@ -3,6 +3,7 @@
 import math
 import logging
 
+__LOGGER__ = logging.getLogger(__name__)
 
 class LearningRateScheduler(object):
     """Base class of learning rate scheduler"""
@@ -58,7 +59,7 @@ def __call__(self, iteration):
         lr = self.base_lr * math.pow(self.factor, int(iteration / self.step))
         if lr != self.old_lr:
             self.old_lr = lr
-            logging.info("At Iteration [%d]: Swith to new learning rate %.5f",
+            __LOGGER__.info("At Iteration [%d]: Switch to new learning rate %.5f",
                          iteration, lr)
         return lr
diff --git a/python/mxnet/model.py b/python/mxnet/model.py
index ee3d6334137e..a51cfc532b59 100644
--- a/python/mxnet/model.py
+++ b/python/mxnet/model.py
@@ -19,6 +19,8 @@
 from .executor import DataParallelExecutorManager, _check_arguments, _load_data
 
 BASE_ESTIMATOR = object
+__LOGGER__ = logging.getLogger(__name__)
+__LOGGER__.setLevel(logging.INFO)
 
 try:
     from sklearn.base import BaseEstimator
@@ -175,8 +177,7 @@ def _train_multi_device(symbol, ctx, arg_names, param_names, aux_names,
     - This function will inplace update the NDArrays in arg_params and aux_states.
     """
     if logger is None:
-        logger = logging.getLogger(__name__)
-        logger.setLevel(logging.INFO)
+        logger = __LOGGER__
     executor_manager = DataParallelExecutorManager(symbol=symbol,
                                                    ctx=ctx,
                                                    train_data=train_data,
diff --git a/python/mxnet/monitor.py b/python/mxnet/monitor.py
index 6fb84cc3ee55..12f98b3dc868 100644
--- a/python/mxnet/monitor.py
+++ b/python/mxnet/monitor.py
@@ -9,6 +9,8 @@
 from math import sqrt
 import re
 
+__LOGGER__ = logging.getLogger(__name__)
+
 class Monitor(object):
     """Monitor outputs, weights, and gradients for debugging.
 
@@ -113,7 +115,7 @@ def toc_print(self):
         """End collecting and print results"""
         res = self.toc()
         for n, k, v in res:
-            logging.info('Batch: {:7d} {:30s} {:s}'.format(n, k, v))
+            __LOGGER__.info('Batch: {:7d} {:30s} {:s}'.format(n, k, v))
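All of these hunks follow one pattern: each module owns a logger created with logging.getLogger(__name__), so the loggers form an "mxnet.*" hierarchy and an application can tune verbosity per module instead of reconfiguring the root logger. A minimal sketch of the caller side (the handler setup is illustrative, not part of the patch):

    import logging

    logging.basicConfig(level=logging.WARNING)                  # quiet default
    logging.getLogger('mxnet.model').setLevel(logging.INFO)     # show training progress
    logging.getLogger('mxnet.monitor').setLevel(logging.DEBUG)  # verbose debugging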
From 70a5bdad21f2477ec25b4f1c95704123a5afae15 Mon Sep 17 00:00:00 2001
From: haijieg
Date: Thu, 31 Mar 2016 14:51:51 -0700
Subject: [PATCH 4/8] Add builtin symbols

---
 python/mxnet/builtin_symbols/__init__.py         |   0
 python/mxnet/builtin_symbols/symbol_alexnet.py   |  55 ++++++
 python/mxnet/builtin_symbols/symbol_googlenet.py |  61 ++++++
 .../symbol_inception-bn-28-small.py              |  63 +++++++
 .../symbol_inception-bn-full.py                  |  86 +++++++++
 .../builtin_symbols/symbol_inception-bn.py       |  86 +++++++++
 .../builtin_symbols/symbol_inception-v3.py       | 176 ++++++++++++++++++
 python/mxnet/builtin_symbols/symbol_vgg.py       |  70 +++++++
 8 files changed, 597 insertions(+)
 create mode 100644 python/mxnet/builtin_symbols/__init__.py
 create mode 100644 python/mxnet/builtin_symbols/symbol_alexnet.py
 create mode 100644 python/mxnet/builtin_symbols/symbol_googlenet.py
 create mode 100644 python/mxnet/builtin_symbols/symbol_inception-bn-28-small.py
 create mode 100644 python/mxnet/builtin_symbols/symbol_inception-bn-full.py
 create mode 100644 python/mxnet/builtin_symbols/symbol_inception-bn.py
 create mode 100644 python/mxnet/builtin_symbols/symbol_inception-v3.py
 create mode 100644 python/mxnet/builtin_symbols/symbol_vgg.py

diff --git a/python/mxnet/builtin_symbols/__init__.py b/python/mxnet/builtin_symbols/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/python/mxnet/builtin_symbols/symbol_alexnet.py b/python/mxnet/builtin_symbols/symbol_alexnet.py
new file mode 100644
index 000000000000..d255b2fc0ce3
--- /dev/null
+++ b/python/mxnet/builtin_symbols/symbol_alexnet.py
@@ -0,0 +1,55 @@
+from ... import mxnet as mx
+
+def get_symbol(num_classes = 1000):
+    """
+    Return the "AlexNet" architecture for image classification
+
+    Parameters
+    ----------
+    num_classes : int, optional
+        Number of classes in the output layer.
+
+    References
+    ----------
+    - Krizhevsky, Alex, Ilya Sutskever, and Geoffrey E. Hinton. "Imagenet
+      classification with deep convolutional neural networks." Advances in neural
+      information processing systems. 2012.
+    """
+    input_data = mx.symbol.Variable(name="data")
+    # stage 1
+    conv1 = mx.symbol.Convolution(
+        data=input_data, kernel=(11, 11), stride=(4, 4), num_filter=96)
+    relu1 = mx.symbol.Activation(data=conv1, act_type="relu")
+    pool1 = mx.symbol.Pooling(
+        data=relu1, pool_type="max", kernel=(3, 3), stride=(2,2))
+    lrn1 = mx.symbol.LRN(data=pool1, alpha=0.0001, beta=0.75, knorm=1, nsize=5)
+    # stage 2
+    conv2 = mx.symbol.Convolution(
+        data=lrn1, kernel=(5, 5), pad=(2, 2), num_filter=256)
+    relu2 = mx.symbol.Activation(data=conv2, act_type="relu")
+    pool2 = mx.symbol.Pooling(data=relu2, kernel=(3, 3), stride=(2, 2), pool_type="max")
+    lrn2 = mx.symbol.LRN(data=pool2, alpha=0.0001, beta=0.75, knorm=1, nsize=5)
+    # stage 3
+    conv3 = mx.symbol.Convolution(
+        data=lrn2, kernel=(3, 3), pad=(1, 1), num_filter=384)
+    relu3 = mx.symbol.Activation(data=conv3, act_type="relu")
+    conv4 = mx.symbol.Convolution(
+        data=relu3, kernel=(3, 3), pad=(1, 1), num_filter=384)
+    relu4 = mx.symbol.Activation(data=conv4, act_type="relu")
+    conv5 = mx.symbol.Convolution(
+        data=relu4, kernel=(3, 3), pad=(1, 1), num_filter=256)
+    relu5 = mx.symbol.Activation(data=conv5, act_type="relu")
+    pool3 = mx.symbol.Pooling(data=relu5, kernel=(3, 3), stride=(2, 2), pool_type="max")
+    # stage 4
+    flatten = mx.symbol.Flatten(data=pool3)
+    fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=4096)
+    relu6 = mx.symbol.Activation(data=fc1, act_type="relu")
+    dropout1 = mx.symbol.Dropout(data=relu6, p=0.5)
+    # stage 5
+    fc2 = mx.symbol.FullyConnected(data=dropout1, num_hidden=4096)
+    relu7 = mx.symbol.Activation(data=fc2, act_type="relu")
+    dropout2 = mx.symbol.Dropout(data=relu7, p=0.5)
+    # stage 6
+    fc3 = mx.symbol.FullyConnected(data=dropout2, num_hidden=num_classes)
+    softmax = mx.symbol.SoftmaxOutput(data=fc3, name='softmax')
+    return softmax
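Each get_symbol factory returns an unbound symbolic graph, so a quick sanity check is to infer shapes for a dummy batch. A sketch (assuming the relative-import fix that lands later in this series, and that this fork is the installed mxnet):

    import mxnet as mx
    from mxnet.builtin_symbols import symbol_alexnet

    net = symbol_alexnet.get_symbol(num_classes=1000)
    arg_shapes, out_shapes, aux_shapes = net.infer_shape(data=(1, 3, 224, 224))
    print(out_shapes)  # [(1, 1000)]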
+ """ + input_data = mx.symbol.Variable(name="data") + # stage 1 + conv1 = mx.symbol.Convolution( + data=input_data, kernel=(11, 11), stride=(4, 4), num_filter=96) + relu1 = mx.symbol.Activation(data=conv1, act_type="relu") + pool1 = mx.symbol.Pooling( + data=relu1, pool_type="max", kernel=(3, 3), stride=(2,2)) + lrn1 = mx.symbol.LRN(data=pool1, alpha=0.0001, beta=0.75, knorm=1, nsize=5) + # stage 2 + conv2 = mx.symbol.Convolution( + data=lrn1, kernel=(5, 5), pad=(2, 2), num_filter=256) + relu2 = mx.symbol.Activation(data=conv2, act_type="relu") + pool2 = mx.symbol.Pooling(data=relu2, kernel=(3, 3), stride=(2, 2), pool_type="max") + lrn2 = mx.symbol.LRN(data=pool2, alpha=0.0001, beta=0.75, knorm=1, nsize=5) + # stage 3 + conv3 = mx.symbol.Convolution( + data=lrn2, kernel=(3, 3), pad=(1, 1), num_filter=384) + relu3 = mx.symbol.Activation(data=conv3, act_type="relu") + conv4 = mx.symbol.Convolution( + data=relu3, kernel=(3, 3), pad=(1, 1), num_filter=384) + relu4 = mx.symbol.Activation(data=conv4, act_type="relu") + conv5 = mx.symbol.Convolution( + data=relu4, kernel=(3, 3), pad=(1, 1), num_filter=256) + relu5 = mx.symbol.Activation(data=conv5, act_type="relu") + pool3 = mx.symbol.Pooling(data=relu5, kernel=(3, 3), stride=(2, 2), pool_type="max") + # stage 4 + flatten = mx.symbol.Flatten(data=pool3) + fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=4096) + relu6 = mx.symbol.Activation(data=fc1, act_type="relu") + dropout1 = mx.symbol.Dropout(data=relu6, p=0.5) + # stage 5 + fc2 = mx.symbol.FullyConnected(data=dropout1, num_hidden=4096) + relu7 = mx.symbol.Activation(data=fc2, act_type="relu") + dropout2 = mx.symbol.Dropout(data=relu7, p=0.5) + # stage 6 + fc3 = mx.symbol.FullyConnected(data=dropout2, num_hidden=num_classes) + softmax = mx.symbol.SoftmaxOutput(data=fc3, name='softmax') + return softmax diff --git a/python/mxnet/builtin_symbols/symbol_googlenet.py b/python/mxnet/builtin_symbols/symbol_googlenet.py new file mode 100644 index 000000000000..4fcf3fa5f223 --- /dev/null +++ b/python/mxnet/builtin_symbols/symbol_googlenet.py @@ -0,0 +1,61 @@ +from ... 
diff --git a/python/mxnet/builtin_symbols/symbol_inception-bn-28-small.py b/python/mxnet/builtin_symbols/symbol_inception-bn-28-small.py
new file mode 100644
index 000000000000..d872559a8f30
--- /dev/null
+++ b/python/mxnet/builtin_symbols/symbol_inception-bn-28-small.py
@@ -0,0 +1,63 @@
+from ... import mxnet as mx
+
+# Basic Conv + BN + ReLU factory
+def ConvFactory(data, num_filter, kernel, stride=(1,1), pad=(0, 0), act_type="relu"):
+    conv = mx.symbol.Convolution(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad)
+    bn = mx.symbol.BatchNorm(data=conv)
+    act = mx.symbol.Activation(data = bn, act_type=act_type)
+    return act
+
+# A Simple Downsampling Factory
+def DownsampleFactory(data, ch_3x3):
+    # conv 3x3
+    conv = ConvFactory(data=data, kernel=(3, 3), stride=(2, 2), num_filter=ch_3x3, pad=(1, 1))
+    # pool
+    pool = mx.symbol.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pool_type='max')
+    # concat
+    concat = mx.symbol.Concat(*[conv, pool])
+    return concat
+
+# A Simple module
+def SimpleFactory(data, ch_1x1, ch_3x3):
+    # 1x1
+    conv1x1 = ConvFactory(data=data, kernel=(1, 1), pad=(0, 0), num_filter=ch_1x1)
+    # 3x3
+    conv3x3 = ConvFactory(data=data, kernel=(3, 3), pad=(1, 1), num_filter=ch_3x3)
+    #concat
+    concat = mx.symbol.Concat(*[conv1x1, conv3x3])
+    return concat
+
+def get_symbol(num_classes = 10):
+    """
+    Return a simplified version of the "BN-Inception" architecture for image classification
+
+    The network is suitable for images of size around 28 x 28
+
+    Parameters
+    ----------
+    num_classes : int, optional
+        Number of classes in the output layer.
+
+    References
+    ----------
+    - Sergey Ioffe and Christian Szegedy. Batch normalization: Accelerating deep
+      network training by reducing internal covariate shift. arXiv preprint
+      arXiv:1502.03167, 2015.
+    """
+    data = mx.symbol.Variable(name="data")
+    conv1 = ConvFactory(data=data, kernel=(3,3), pad=(1,1), num_filter=96, act_type="relu")
+    in3a = SimpleFactory(conv1, 32, 32)
+    in3b = SimpleFactory(in3a, 32, 48)
+    in3c = DownsampleFactory(in3b, 80)
+    in4a = SimpleFactory(in3c, 112, 48)
+    in4b = SimpleFactory(in4a, 96, 64)
+    in4c = SimpleFactory(in4b, 80, 80)
+    in4d = SimpleFactory(in4c, 48, 96)
+    in4e = DownsampleFactory(in4d, 96)
+    in5a = SimpleFactory(in4e, 176, 160)
+    in5b = SimpleFactory(in5a, 176, 160)
+    pool = mx.symbol.Pooling(data=in5b, pool_type="avg", kernel=(7,7), name="global_pool")
+    flatten = mx.symbol.Flatten(data=pool, name="flatten1")
+    fc = mx.symbol.FullyConnected(data=flatten, num_hidden=num_classes, name="fc1")
+    softmax = mx.symbol.SoftmaxOutput(data=fc, name="softmax")
+    return softmax
diff --git a/python/mxnet/builtin_symbols/symbol_inception-bn-full.py b/python/mxnet/builtin_symbols/symbol_inception-bn-full.py
new file mode 100644
index 000000000000..c8dbf56a3d93
--- /dev/null
+++ b/python/mxnet/builtin_symbols/symbol_inception-bn-full.py
@@ -0,0 +1,86 @@
+from ... import mxnet as mx
+
+def ConvFactory(data, num_filter, kernel, stride=(1,1), pad=(0, 0), name=None, suffix=''):
+    conv = mx.symbol.Convolution(data=data, workspace=512, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad, name='conv_%s%s' %(name, suffix))
+    bn = mx.symbol.BatchNorm(data=conv, name='bn_%s%s' %(name, suffix))
+    act = mx.symbol.Activation(data=bn, act_type='relu', name='relu_%s%s' %(name, suffix))
+    return act
+
+def InceptionFactoryA(data, num_1x1, num_3x3red, num_3x3, num_d3x3red, num_d3x3, pool, proj, name):
+    # 1x1
+    c1x1 = ConvFactory(data=data, num_filter=num_1x1, kernel=(1, 1), name=('%s_1x1' % name))
+    # 3x3 reduce + 3x3
+    c3x3r = ConvFactory(data=data, num_filter=num_3x3red, kernel=(1, 1), name=('%s_3x3' % name), suffix='_reduce')
+    c3x3 = ConvFactory(data=c3x3r, num_filter=num_3x3, kernel=(3, 3), pad=(1, 1), name=('%s_3x3' % name))
+    # double 3x3 reduce + double 3x3
+    cd3x3r = ConvFactory(data=data, num_filter=num_d3x3red, kernel=(1, 1), name=('%s_double_3x3' % name), suffix='_reduce')
+    cd3x3 = ConvFactory(data=cd3x3r, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), name=('%s_double_3x3_0' % name))
+    cd3x3 = ConvFactory(data=cd3x3, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), name=('%s_double_3x3_1' % name))
+    # pool + proj
+    pooling = mx.symbol.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
+    cproj = ConvFactory(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_proj' % name))
+    # concat
+    concat = mx.symbol.Concat(*[c1x1, c3x3, cd3x3, cproj], name='ch_concat_%s_chconcat' % name)
+    return concat
+
+def InceptionFactoryB(data, num_3x3red, num_3x3, num_d3x3red, num_d3x3, name):
+    # 3x3 reduce + 3x3
+    c3x3r = ConvFactory(data=data, num_filter=num_3x3red, kernel=(1, 1), name=('%s_3x3' % name), suffix='_reduce')
+    c3x3 = ConvFactory(data=c3x3r, num_filter=num_3x3, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name=('%s_3x3' % name))
+    # double 3x3 reduce + double 3x3
+    cd3x3r = ConvFactory(data=data, num_filter=num_d3x3red, kernel=(1, 1), name=('%s_double_3x3' % name), suffix='_reduce')
+    cd3x3 = ConvFactory(data=cd3x3r, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name=('%s_double_3x3_0' % name))
+    cd3x3 = ConvFactory(data=cd3x3, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name=('%s_double_3x3_1' % name))
+    # pool + proj
+    pooling = mx.symbol.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pool_type="max", name=('max_pool_%s_pool' % name))
+    # concat
+    concat = mx.symbol.Concat(*[c3x3, cd3x3, pooling], name='ch_concat_%s_chconcat' % name)
+    return concat
+
+def get_symbol(num_classes = 21841):
+    """
+    Return a variant of the "BN-Inception" architecture for image classification
+
+    The network is suitable for the full ImageNet dataset with 21841 classes
+
+    Parameters
+    ----------
+    num_classes : int, optional
+        Number of classes in the output layer.
+
+    References
+    ----------
+    - Sergey Ioffe and Christian Szegedy. Batch normalization: Accelerating deep
+      network training by reducing internal covariate shift. arXiv preprint
+      arXiv:1502.03167, 2015.
+    """
+
+    # data
+    data = mx.symbol.Variable(name="data")
+    # stage 1
+    conv1 = ConvFactory(data=data, num_filter=96, kernel=(7, 7), stride=(2, 2), pad=(3, 3), name='conv1')
+    pool1 = mx.symbol.Pooling(data=conv1, kernel=(3, 3), stride=(2, 2), name='pool1', pool_type='max')
+    # stage 2
+    conv2red = ConvFactory(data=pool1, num_filter=128, kernel=(1, 1), stride=(1, 1), name='conv2red')
+    conv2 = ConvFactory(data=conv2red, num_filter=288, kernel=(3, 3), stride=(1, 1), pad=(1, 1), name='conv2')
+    pool2 = mx.symbol.Pooling(data=conv2, kernel=(3, 3), stride=(2, 2), name='pool2', pool_type='max')
+    # stage 2
+    in3a = InceptionFactoryA(pool2, 96, 96, 96, 96, 144, "avg", 48, '3a')
+    in3b = InceptionFactoryA(in3a, 96, 96, 144, 96, 144, "avg", 96, '3b')
+    in3c = InceptionFactoryB(in3b, 192, 240, 96, 144, '3c')
+    # stage 3
+    in4a = InceptionFactoryA(in3c, 224, 64, 96, 96, 128, "avg", 128, '4a')
+    in4b = InceptionFactoryA(in4a, 192, 96, 128, 96, 128, "avg", 128, '4b')
+    in4c = InceptionFactoryA(in4b, 160, 128, 160, 128, 160, "avg", 128, '4c')
+    in4d = InceptionFactoryA(in4c, 96, 128, 192, 160, 96, "avg", 128, '4d')
+    in4e = InceptionFactoryB(in4d, 128, 192, 192, 256, '4e')
+    # stage 4
+    in5a = InceptionFactoryA(in4e, 352, 192, 320, 160, 224, "avg", 128, '5a')
+    in5b = InceptionFactoryA(in5a, 352, 192, 320, 192, 224, "max", 128, '5b')
+    # global avg pooling
+    avg = mx.symbol.Pooling(data=in5b, kernel=(7, 7), stride=(1, 1), name="global_pool", pool_type='avg')
+    # linear classifier
+    flatten = mx.symbol.Flatten(data=avg, name='flatten')
+    fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=num_classes, name='fc1')
+    softmax = mx.symbol.SoftmaxOutput(data=fc1, name='softmax')
+    return softmax
diff --git a/python/mxnet/builtin_symbols/symbol_inception-bn.py b/python/mxnet/builtin_symbols/symbol_inception-bn.py
new file mode 100644
index 000000000000..81eb1a97b7be
--- /dev/null
+++ b/python/mxnet/builtin_symbols/symbol_inception-bn.py
@@ -0,0 +1,86 @@
+from ... import mxnet as mx
+
+def ConvFactory(data, num_filter, kernel, stride=(1,1), pad=(0, 0), name=None, suffix=''):
+    conv = mx.symbol.Convolution(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad, name='conv_%s%s' %(name, suffix))
+    bn = mx.symbol.BatchNorm(data=conv, name='bn_%s%s' %(name, suffix))
+    act = mx.symbol.Activation(data=bn, act_type='relu', name='relu_%s%s' %(name, suffix))
+    return act
+
+def InceptionFactoryA(data, num_1x1, num_3x3red, num_3x3, num_d3x3red, num_d3x3, pool, proj, name):
+    # 1x1
+    c1x1 = ConvFactory(data=data, num_filter=num_1x1, kernel=(1, 1), name=('%s_1x1' % name))
+    # 3x3 reduce + 3x3
+    c3x3r = ConvFactory(data=data, num_filter=num_3x3red, kernel=(1, 1), name=('%s_3x3' % name), suffix='_reduce')
+    c3x3 = ConvFactory(data=c3x3r, num_filter=num_3x3, kernel=(3, 3), pad=(1, 1), name=('%s_3x3' % name))
+    # double 3x3 reduce + double 3x3
+    cd3x3r = ConvFactory(data=data, num_filter=num_d3x3red, kernel=(1, 1), name=('%s_double_3x3' % name), suffix='_reduce')
+    cd3x3 = ConvFactory(data=cd3x3r, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), name=('%s_double_3x3_0' % name))
+    cd3x3 = ConvFactory(data=cd3x3, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), name=('%s_double_3x3_1' % name))
+    # pool + proj
+    pooling = mx.symbol.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
+    cproj = ConvFactory(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_proj' % name))
+    # concat
+    concat = mx.symbol.Concat(*[c1x1, c3x3, cd3x3, cproj], name='ch_concat_%s_chconcat' % name)
+    return concat
+
+def InceptionFactoryB(data, num_3x3red, num_3x3, num_d3x3red, num_d3x3, name):
+    # 3x3 reduce + 3x3
+    c3x3r = ConvFactory(data=data, num_filter=num_3x3red, kernel=(1, 1), name=('%s_3x3' % name), suffix='_reduce')
+    c3x3 = ConvFactory(data=c3x3r, num_filter=num_3x3, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name=('%s_3x3' % name))
+    # double 3x3 reduce + double 3x3
+    cd3x3r = ConvFactory(data=data, num_filter=num_d3x3red, kernel=(1, 1), name=('%s_double_3x3' % name), suffix='_reduce')
+    cd3x3 = ConvFactory(data=cd3x3r, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name=('%s_double_3x3_0' % name))
+    cd3x3 = ConvFactory(data=cd3x3, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name=('%s_double_3x3_1' % name))
+    # pool + proj
+    pooling = mx.symbol.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pool_type="max", name=('max_pool_%s_pool' % name))
+    # concat
+    concat = mx.symbol.Concat(*[c3x3, cd3x3, pooling], name='ch_concat_%s_chconcat' % name)
+    return concat
+
+def get_symbol(num_classes=1000):
+    """
+    Return the "BN-Inception" architecture for image classification
+
+    The network is suitable for images of size around 224 x 224
+
+    Parameters
+    ----------
+    num_classes : int, optional
+        Number of classes in the output layer.
+
+    References
+    ----------
+    - Sergey Ioffe and Christian Szegedy. Batch normalization: Accelerating deep
+      network training by reducing internal covariate shift. arXiv preprint
+      arXiv:1502.03167, 2015.
+    """
+
+    # data
+    data = mx.symbol.Variable(name="data")
+    # stage 1
+    conv1 = ConvFactory(data=data, num_filter=64, kernel=(7, 7), stride=(2, 2), pad=(3, 3), name='conv1')
+    pool1 = mx.symbol.Pooling(data=conv1, kernel=(3, 3), stride=(2, 2), name='pool1', pool_type='max')
+    # stage 2
+    conv2red = ConvFactory(data=pool1, num_filter=64, kernel=(1, 1), stride=(1, 1), name='conv2red')
+    conv2 = ConvFactory(data=conv2red, num_filter=192, kernel=(3, 3), stride=(1, 1), pad=(1, 1), name='conv2')
+    pool2 = mx.symbol.Pooling(data=conv2, kernel=(3, 3), stride=(2, 2), name='pool2', pool_type='max')
+    # stage 2
+    in3a = InceptionFactoryA(pool2, 64, 64, 64, 64, 96, "avg", 32, '3a')
+    in3b = InceptionFactoryA(in3a, 64, 64, 96, 64, 96, "avg", 64, '3b')
+    in3c = InceptionFactoryB(in3b, 128, 160, 64, 96, '3c')
+    # stage 3
+    in4a = InceptionFactoryA(in3c, 224, 64, 96, 96, 128, "avg", 128, '4a')
+    in4b = InceptionFactoryA(in4a, 192, 96, 128, 96, 128, "avg", 128, '4b')
+    in4c = InceptionFactoryA(in4b, 160, 128, 160, 128, 160, "avg", 128, '4c')
+    in4d = InceptionFactoryA(in4c, 96, 128, 192, 160, 192, "avg", 128, '4d')
+    in4e = InceptionFactoryB(in4d, 128, 192, 192, 256, '4e')
+    # stage 4
+    in5a = InceptionFactoryA(in4e, 352, 192, 320, 160, 224, "avg", 128, '5a')
+    in5b = InceptionFactoryA(in5a, 352, 192, 320, 192, 224, "max", 128, '5b')
+    # global avg pooling
+    avg = mx.symbol.Pooling(data=in5b, kernel=(7, 7), stride=(1, 1), name="global_pool", pool_type='avg')
+    # linear classifier
+    flatten = mx.symbol.Flatten(data=avg, name='flatten')
+    fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=num_classes, name='fc1')
+    softmax = mx.symbol.SoftmaxOutput(data=fc1, name='softmax')
+    return softmax
diff --git a/python/mxnet/builtin_symbols/symbol_inception-v3.py b/python/mxnet/builtin_symbols/symbol_inception-v3.py
new file mode 100644
index 000000000000..b5c744164604
--- /dev/null
+++ b/python/mxnet/builtin_symbols/symbol_inception-v3.py
@@ -0,0 +1,176 @@
+from ... import mxnet as mx
+
+def Conv(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), name=None, suffix=''):
+    conv = mx.sym.Convolution(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad, no_bias=True, name='%s%s_conv2d' %(name, suffix))
+    bn = mx.sym.BatchNorm(data=conv, name='%s%s_batchnorm' %(name, suffix), fix_gamma=True)
+    act = mx.sym.Activation(data=bn, act_type='relu', name='%s%s_relu' %(name, suffix))
+    return act
+
+
+def Inception7A(data,
+                num_1x1,
+                num_3x3_red, num_3x3_1, num_3x3_2,
+                num_5x5_red, num_5x5,
+                pool, proj,
+                name):
+    tower_1x1 = Conv(data, num_1x1, name=('%s_conv' % name))
+    tower_5x5 = Conv(data, num_5x5_red, name=('%s_tower' % name), suffix='_conv')
+    tower_5x5 = Conv(tower_5x5, num_5x5, kernel=(5, 5), pad=(2, 2), name=('%s_tower' % name), suffix='_conv_1')
+    tower_3x3 = Conv(data, num_3x3_red, name=('%s_tower_1' % name), suffix='_conv')
+    tower_3x3 = Conv(tower_3x3, num_3x3_1, kernel=(3, 3), pad=(1, 1), name=('%s_tower_1' % name), suffix='_conv_1')
+    tower_3x3 = Conv(tower_3x3, num_3x3_2, kernel=(3, 3), pad=(1, 1), name=('%s_tower_1' % name), suffix='_conv_2')
+    pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
+    cproj = Conv(pooling, proj, name=('%s_tower_2' % name), suffix='_conv')
+    concat = mx.sym.Concat(*[tower_1x1, tower_5x5, tower_3x3, cproj], name='ch_concat_%s_chconcat' % name)
+    return concat
+
+# First Downsample
+def Inception7B(data,
+                num_3x3,
+                num_d3x3_red, num_d3x3_1, num_d3x3_2,
+                pool,
+                name):
+    tower_3x3 = Conv(data, num_3x3, kernel=(3, 3), pad=(0, 0), stride=(2, 2), name=('%s_conv' % name))
+    tower_d3x3 = Conv(data, num_d3x3_red, name=('%s_tower' % name), suffix='_conv')
+    tower_d3x3 = Conv(tower_d3x3, num_d3x3_1, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name=('%s_tower' % name), suffix='_conv_1')
+    tower_d3x3 = Conv(tower_d3x3, num_d3x3_2, kernel=(3, 3), pad=(0, 0), stride=(2, 2), name=('%s_tower' % name), suffix='_conv_2')
+    pooling = mx.symbol.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pad=(0,0), pool_type="max", name=('max_pool_%s_pool' % name))
+    concat = mx.sym.Concat(*[tower_3x3, tower_d3x3, pooling], name='ch_concat_%s_chconcat' % name)
+    return concat
+
+def Inception7C(data,
+                num_1x1,
+                num_d7_red, num_d7_1, num_d7_2,
+                num_q7_red, num_q7_1, num_q7_2, num_q7_3, num_q7_4,
+                pool, proj,
+                name):
+    tower_1x1 = Conv(data=data, num_filter=num_1x1, kernel=(1, 1), name=('%s_conv' % name))
+    tower_d7 = Conv(data=data, num_filter=num_d7_red, name=('%s_tower' % name), suffix='_conv')
+    tower_d7 = Conv(data=tower_d7, num_filter=num_d7_1, kernel=(1, 7), pad=(0, 3), name=('%s_tower' % name), suffix='_conv_1')
+    tower_d7 = Conv(data=tower_d7, num_filter=num_d7_2, kernel=(7, 1), pad=(3, 0), name=('%s_tower' % name), suffix='_conv_2')
+    tower_q7 = Conv(data=data, num_filter=num_q7_red, name=('%s_tower_1' % name), suffix='_conv')
+    tower_q7 = Conv(data=tower_q7, num_filter=num_q7_1, kernel=(7, 1), pad=(3, 0), name=('%s_tower_1' % name), suffix='_conv_1')
+    tower_q7 = Conv(data=tower_q7, num_filter=num_q7_2, kernel=(1, 7), pad=(0, 3), name=('%s_tower_1' % name), suffix='_conv_2')
+    tower_q7 = Conv(data=tower_q7, num_filter=num_q7_3, kernel=(7, 1), pad=(3, 0), name=('%s_tower_1' % name), suffix='_conv_3')
+    tower_q7 = Conv(data=tower_q7, num_filter=num_q7_4, kernel=(1, 7), pad=(0, 3), name=('%s_tower_1' % name), suffix='_conv_4')
+    pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
+    cproj = Conv(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_tower_2' % name), suffix='_conv')
+    # concat
+    concat = mx.sym.Concat(*[tower_1x1, tower_d7, tower_q7, cproj], name='ch_concat_%s_chconcat' % name)
+    return concat
+
+def Inception7D(data,
+                num_3x3_red, num_3x3,
+                num_d7_3x3_red, num_d7_1, num_d7_2, num_d7_3x3,
+                pool,
+                name):
+    tower_3x3 = Conv(data=data, num_filter=num_3x3_red, name=('%s_tower' % name), suffix='_conv')
+    tower_3x3 = Conv(data=tower_3x3, num_filter=num_3x3, kernel=(3, 3), pad=(0,0), stride=(2, 2), name=('%s_tower' % name), suffix='_conv_1')
+    tower_d7_3x3 = Conv(data=data, num_filter=num_d7_3x3_red, name=('%s_tower_1' % name), suffix='_conv')
+    tower_d7_3x3 = Conv(data=tower_d7_3x3, num_filter=num_d7_1, kernel=(1, 7), pad=(0, 3), name=('%s_tower_1' % name), suffix='_conv_1')
+    tower_d7_3x3 = Conv(data=tower_d7_3x3, num_filter=num_d7_2, kernel=(7, 1), pad=(3, 0), name=('%s_tower_1' % name), suffix='_conv_2')
+    tower_d7_3x3 = Conv(data=tower_d7_3x3, num_filter=num_d7_3x3, kernel=(3, 3), stride=(2, 2), name=('%s_tower_1' % name), suffix='_conv_3')
+    pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
+    # concat
+    concat = mx.sym.Concat(*[tower_3x3, tower_d7_3x3, pooling], name='ch_concat_%s_chconcat' % name)
+    return concat
+
+def Inception7E(data,
+                num_1x1,
+                num_d3_red, num_d3_1, num_d3_2,
+                num_3x3_d3_red, num_3x3, num_3x3_d3_1, num_3x3_d3_2,
+                pool, proj,
+                name):
+    tower_1x1 = Conv(data=data, num_filter=num_1x1, kernel=(1, 1), name=('%s_conv' % name))
+    tower_d3 = Conv(data=data, num_filter=num_d3_red, name=('%s_tower' % name), suffix='_conv')
+    tower_d3_a = Conv(data=tower_d3, num_filter=num_d3_1, kernel=(1, 3), pad=(0, 1), name=('%s_tower' % name), suffix='_mixed_conv')
+    tower_d3_b = Conv(data=tower_d3, num_filter=num_d3_2, kernel=(3, 1), pad=(1, 0), name=('%s_tower' % name), suffix='_mixed_conv_1')
+    tower_3x3_d3 = Conv(data=data, num_filter=num_3x3_d3_red, name=('%s_tower_1' % name), suffix='_conv')
+    tower_3x3_d3 = Conv(data=tower_3x3_d3, num_filter=num_3x3, kernel=(3, 3), pad=(1, 1), name=('%s_tower_1' % name), suffix='_conv_1')
+    tower_3x3_d3_a = Conv(data=tower_3x3_d3, num_filter=num_3x3_d3_1, kernel=(1, 3), pad=(0, 1), name=('%s_tower_1' % name), suffix='_mixed_conv')
+    tower_3x3_d3_b = Conv(data=tower_3x3_d3, num_filter=num_3x3_d3_2, kernel=(3, 1), pad=(1, 0), name=('%s_tower_1' % name), suffix='_mixed_conv_1')
+    pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
+    cproj = Conv(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_tower_2' % name), suffix='_conv')
+    # concat
+    concat = mx.sym.Concat(*[tower_1x1, tower_d3_a, tower_d3_b, tower_3x3_d3_a, tower_3x3_d3_b, cproj], name='ch_concat_%s_chconcat' % name)
+    return concat
+
+def get_symbol(num_classes=1000):
+    """
+    Return the "Inception-v3" architecture for image classification
+
+    Parameters
+    ----------
+    num_classes : int, optional
+        Number of classes in the output layer.
+
+    References
+    ----------
+    - Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jonathon Shlens, and Zbigniew Wojna.
+      Rethinking the Inception Architecture for Computer Vision. arXiv preprint arXiv:1512.00567, 2015.
+    """
+
+    data = mx.symbol.Variable(name="data")
+    # stage 1
+    conv = Conv(data, 32, kernel=(3, 3), stride=(2, 2), name="conv")
+    conv_1 = Conv(conv, 32, kernel=(3, 3), name="conv_1")
+    conv_2 = Conv(conv_1, 64, kernel=(3, 3), pad=(1, 1), name="conv_2")
+    pool = mx.sym.Pooling(data=conv_2, kernel=(3, 3), stride=(2, 2), pool_type="max", name="pool")
+    # stage 2
+    conv_3 = Conv(pool, 80, kernel=(1, 1), name="conv_3")
+    conv_4 = Conv(conv_3, 192, kernel=(3, 3), name="conv_4")
+    pool1 = mx.sym.Pooling(data=conv_4, kernel=(3, 3), stride=(2, 2), pool_type="max", name="pool1")
+    # stage 3
+    in3a = Inception7A(pool1, 64,
+                       64, 96, 96,
+                       48, 64,
+                       "avg", 32, "mixed")
+    in3b = Inception7A(in3a, 64,
+                       64, 96, 96,
+                       48, 64,
+                       "avg", 64, "mixed_1")
+    in3c = Inception7A(in3b, 64,
+                       64, 96, 96,
+                       48, 64,
+                       "avg", 64, "mixed_2")
+    in3d = Inception7B(in3c, 384,
+                       64, 96, 96,
+                       "max", "mixed_3")
+    # stage 4
+    in4a = Inception7C(in3d, 192,
+                       128, 128, 192,
+                       128, 128, 128, 128, 192,
+                       "avg", 192, "mixed_4")
+    in4b = Inception7C(in4a, 192,
+                       160, 160, 192,
+                       160, 160, 160, 160, 192,
+                       "avg", 192, "mixed_5")
+    in4c = Inception7C(in4b, 192,
+                       160, 160, 192,
+                       160, 160, 160, 160, 192,
+                       "avg", 192, "mixed_6")
+    in4d = Inception7C(in4c, 192,
+                       192, 192, 192,
+                       192, 192, 192, 192, 192,
+                       "avg", 192, "mixed_7")
+    in4e = Inception7D(in4d, 192, 320,
+                       192, 192, 192, 192,
+                       "max", "mixed_8")
+    # stage 5
+    in5a = Inception7E(in4e, 320,
+                       384, 384, 384,
+                       448, 384, 384, 384,
+                       "avg", 192, "mixed_9")
+    in5b = Inception7E(in5a, 320,
+                       384, 384, 384,
+                       448, 384, 384, 384,
+                       "max", 192, "mixed_10")
+    # pool
+    pool = mx.sym.Pooling(data=in5b, kernel=(8, 8), stride=(1, 1), pool_type="avg", name="global_pool")
+    flatten = mx.sym.Flatten(data=pool, name="flatten")
+    fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=num_classes, name='fc1')
+    softmax = mx.symbol.SoftmaxOutput(data=fc1, name='softmax')
+    return softmax
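The 1x7/7x1 pairs in Inception7C are the factorized convolutions from the paper: two one-dimensional kernels cover the same 7x7 receptive field at a fraction of the cost. A back-of-the-envelope check (assuming equal input and output channel count C):

    C = 192
    full_7x7   = 7 * 7 * C * C        # one 7x7 convolution
    factorized = (7 + 7) * C * C      # a 1x7 followed by a 7x1
    print(factorized / full_7x7)      # ~0.29, i.e. roughly 3.5x fewer weights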
+ """ + + # data + data = mx.symbol.Variable(name="data") + # stage 1 + conv1 = ConvFactory(data=data, num_filter=64, kernel=(7, 7), stride=(2, 2), pad=(3, 3), name='conv1') + pool1 = mx.symbol.Pooling(data=conv1, kernel=(3, 3), stride=(2, 2), name='pool1', pool_type='max') + # stage 2 + conv2red = ConvFactory(data=pool1, num_filter=64, kernel=(1, 1), stride=(1, 1), name='conv2red') + conv2 = ConvFactory(data=conv2red, num_filter=192, kernel=(3, 3), stride=(1, 1), pad=(1, 1), name='conv2') + pool2 = mx.symbol.Pooling(data=conv2, kernel=(3, 3), stride=(2, 2), name='pool2', pool_type='max') + # stage 2 + in3a = InceptionFactoryA(pool2, 64, 64, 64, 64, 96, "avg", 32, '3a') + in3b = InceptionFactoryA(in3a, 64, 64, 96, 64, 96, "avg", 64, '3b') + in3c = InceptionFactoryB(in3b, 128, 160, 64, 96, '3c') + # stage 3 + in4a = InceptionFactoryA(in3c, 224, 64, 96, 96, 128, "avg", 128, '4a') + in4b = InceptionFactoryA(in4a, 192, 96, 128, 96, 128, "avg", 128, '4b') + in4c = InceptionFactoryA(in4b, 160, 128, 160, 128, 160, "avg", 128, '4c') + in4d = InceptionFactoryA(in4c, 96, 128, 192, 160, 192, "avg", 128, '4d') + in4e = InceptionFactoryB(in4d, 128, 192, 192, 256, '4e') + # stage 4 + in5a = InceptionFactoryA(in4e, 352, 192, 320, 160, 224, "avg", 128, '5a') + in5b = InceptionFactoryA(in5a, 352, 192, 320, 192, 224, "max", 128, '5b') + # global avg pooling + avg = mx.symbol.Pooling(data=in5b, kernel=(7, 7), stride=(1, 1), name="global_pool", pool_type='avg') + # linear classifier + flatten = mx.symbol.Flatten(data=avg, name='flatten') + fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=num_classes, name='fc1') + softmax = mx.symbol.SoftmaxOutput(data=fc1, name='softmax') + return softmax diff --git a/python/mxnet/builtin_symbols/symbol_inception-v3.py b/python/mxnet/builtin_symbols/symbol_inception-v3.py new file mode 100644 index 000000000000..b5c744164604 --- /dev/null +++ b/python/mxnet/builtin_symbols/symbol_inception-v3.py @@ -0,0 +1,176 @@ +from ... 
From 60043e7f1f1f00b3941b19f61b0212e218aaef74 Mon Sep 17 00:00:00 2001
From: haijieg
Date: Thu, 31 Mar 2016 15:07:25 -0700
Subject: [PATCH 5/8] Rename illegal module name

---
 .../{symbol_inception-bn.py => symbol_inception_bn.py}                | 0
 ...l_inception-bn-28-small.py => symbol_inception_bn_28_small.py}     | 0
 .../{symbol_inception-bn-full.py => symbol_inception_bn_full.py}      | 0
 .../{symbol_inception-v3.py => symbol_inception_v3.py}                | 0
 4 files changed, 0 insertions(+), 0 deletions(-)
 rename python/mxnet/builtin_symbols/{symbol_inception-bn.py => symbol_inception_bn.py} (100%)
 rename python/mxnet/builtin_symbols/{symbol_inception-bn-28-small.py => symbol_inception_bn_28_small.py} (100%)
 rename python/mxnet/builtin_symbols/{symbol_inception-bn-full.py => symbol_inception_bn_full.py} (100%)
 rename python/mxnet/builtin_symbols/{symbol_inception-v3.py => symbol_inception_v3.py} (100%)

diff --git a/python/mxnet/builtin_symbols/symbol_inception-bn.py b/python/mxnet/builtin_symbols/symbol_inception_bn.py
similarity index 100%
rename from python/mxnet/builtin_symbols/symbol_inception-bn.py
rename to python/mxnet/builtin_symbols/symbol_inception_bn.py
diff --git a/python/mxnet/builtin_symbols/symbol_inception-bn-28-small.py b/python/mxnet/builtin_symbols/symbol_inception_bn_28_small.py
similarity index 100%
rename from python/mxnet/builtin_symbols/symbol_inception-bn-28-small.py
rename to python/mxnet/builtin_symbols/symbol_inception_bn_28_small.py
diff --git a/python/mxnet/builtin_symbols/symbol_inception-bn-full.py b/python/mxnet/builtin_symbols/symbol_inception_bn_full.py
similarity index 100%
rename from python/mxnet/builtin_symbols/symbol_inception-bn-full.py
rename to python/mxnet/builtin_symbols/symbol_inception_bn_full.py
diff --git a/python/mxnet/builtin_symbols/symbol_inception-v3.py b/python/mxnet/builtin_symbols/symbol_inception_v3.py
similarity index 100%
rename from python/mxnet/builtin_symbols/symbol_inception-v3.py
rename to python/mxnet/builtin_symbols/symbol_inception_v3.py
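The renames matter because a module name must be a valid Python identifier: `import symbol_inception-v3` parses as `import symbol_inception` minus `v3` and is a syntax error. Before the rename, the hyphenated files were reachable only indirectly, e.g. (illustrative):

    import importlib
    mod = importlib.import_module('mxnet.builtin_symbols.symbol_inception-v3')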
From e4f673543e11116a459010cebfcf8142149090a9 Mon Sep 17 00:00:00 2001
From: haijieg
Date: Thu, 31 Mar 2016 19:50:23 -0700
Subject: [PATCH 6/8] Fix builtin symbol import

---
 python/mxnet/__init__.py                           |  1 +
 python/mxnet/builtin_symbols/__init__.py           |  7 ++
 python/mxnet/builtin_symbols/symbol_alexnet.py     | 52 +++++++--------
 python/mxnet/builtin_symbols/symbol_googlenet.py   | 28 ++++----
 .../builtin_symbols/symbol_inception_bn.py         | 30 ++++-----
 .../symbol_inception_bn_28_small.py                | 24 +++----
 .../symbol_inception_bn_full.py                    | 30 ++++-----
 .../builtin_symbols/symbol_inception_v3.py         | 42 ++++++------
 python/mxnet/builtin_symbols/symbol_vgg.py         | 64 +++++++++---------
 python/mxnet/callback.py                           |  1 -
 python/mxnet/model.py                              | 32 ++++++++--
 11 files changed, 170 insertions(+), 141 deletions(-)

diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py
index 01d0438e25a0..3d76bb962145 100644
--- a/python/mxnet/__init__.py
+++ b/python/mxnet/__init__.py
@@ -43,5 +43,6 @@
 from . import torch
 from . import torch as th
+from . import builtin_symbols
 
 __version__ = base.__version__
diff --git a/python/mxnet/builtin_symbols/__init__.py b/python/mxnet/builtin_symbols/__init__.py
index e69de29bb2d1..a109ee68998c 100644
--- a/python/mxnet/builtin_symbols/__init__.py
+++ b/python/mxnet/builtin_symbols/__init__.py
@@ -0,0 +1,7 @@
+from . import symbol_alexnet
+from . import symbol_googlenet
+from . import symbol_vgg
+from . import symbol_inception_v3
+from . import symbol_inception_bn
+from . import symbol_inception_bn_full
+from . import symbol_inception_bn_28_small
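With builtin_symbols imported from mxnet/__init__.py, the factories become reachable from the top-level namespace. A minimal sketch (assuming this fork is the installed mxnet package):

    import mxnet as mx

    net = mx.builtin_symbols.symbol_inception_v3.get_symbol(num_classes=1000)
    print(type(net))  # <class 'mxnet.symbol.Symbol'>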
+ """ + + data = mx.symbol.Variable(name="data") + # stage 1 + conv = Conv(data, 32, kernel=(3, 3), stride=(2, 2), name="conv") + conv_1 = Conv(conv, 32, kernel=(3, 3), name="conv_1") + conv_2 = Conv(conv_1, 64, kernel=(3, 3), pad=(1, 1), name="conv_2") + pool = mx.sym.Pooling(data=conv_2, kernel=(3, 3), stride=(2, 2), pool_type="max", name="pool") + # stage 2 + conv_3 = Conv(pool, 80, kernel=(1, 1), name="conv_3") + conv_4 = Conv(conv_3, 192, kernel=(3, 3), name="conv_4") + pool1 = mx.sym.Pooling(data=conv_4, kernel=(3, 3), stride=(2, 2), pool_type="max", name="pool1") + # stage 3 + in3a = Inception7A(pool1, 64, + 64, 96, 96, + 48, 64, + "avg", 32, "mixed") + in3b = Inception7A(in3a, 64, + 64, 96, 96, + 48, 64, + "avg", 64, "mixed_1") + in3c = Inception7A(in3b, 64, + 64, 96, 96, + 48, 64, + "avg", 64, "mixed_2") + in3d = Inception7B(in3c, 384, + 64, 96, 96, + "max", "mixed_3") + # stage 4 + in4a = Inception7C(in3d, 192, + 128, 128, 192, + 128, 128, 128, 128, 192, + "avg", 192, "mixed_4") + in4b = Inception7C(in4a, 192, + 160, 160, 192, + 160, 160, 160, 160, 192, + "avg", 192, "mixed_5") + in4c = Inception7C(in4b, 192, + 160, 160, 192, + 160, 160, 160, 160, 192, + "avg", 192, "mixed_6") + in4d = Inception7C(in4c, 192, + 192, 192, 192, + 192, 192, 192, 192, 192, + "avg", 192, "mixed_7") + in4e = Inception7D(in4d, 192, 320, + 192, 192, 192, 192, + "max", "mixed_8") + # stage 5 + in5a = Inception7E(in4e, 320, + 384, 384, 384, + 448, 384, 384, 384, + "avg", 192, "mixed_9") + in5b = Inception7E(in5a, 320, + 384, 384, 384, + 448, 384, 384, 384, + "max", 192, "mixed_10") + # pool + pool = mx.sym.Pooling(data=in5b, kernel=(8, 8), stride=(1, 1), pool_type="avg", name="global_pool") + flatten = mx.sym.Flatten(data=pool, name="flatten") + fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=num_classes, name='fc1') + softmax = mx.symbol.SoftmaxOutput(data=fc1, name='softmax') + return softmax + diff --git a/python/mxnet/builtin_symbols/symbol_vgg.py b/python/mxnet/builtin_symbols/symbol_vgg.py new file mode 100644 index 000000000000..b2de48ef9435 --- /dev/null +++ b/python/mxnet/builtin_symbols/symbol_vgg.py @@ -0,0 +1,70 @@ +from ... import mxnet as mx + +def get_symbol(num_classes = 1000): + """ + Return the "VGG" architecture for image classification + + Parameters + ---------- + num_classes : int, optional + Number of classes in the ouptut layer. + + References + ---------- + - Simonyan, Karen, and Andrew Zisserman. "Very deep convolutional networks for + large-scale image recognition." arXiv preprint arXiv:1409.1556, 2014. 
+ """ + + ## define alexnet + data = mx.symbol.Variable(name="data") + # group 1 + conv1_1 = mx.symbol.Convolution(data=data, kernel=(3, 3), pad=(1, 1), num_filter=64, name="conv1_1") + relu1_1 = mx.symbol.Activation(data=conv1_1, act_type="relu", name="relu1_1") + pool1 = mx.symbol.Pooling( + data=relu1_1, pool_type="max", kernel=(2, 2), stride=(2,2), name="pool1") + # group 2 + conv2_1 = mx.symbol.Convolution( + data=pool1, kernel=(3, 3), pad=(1, 1), num_filter=128, name="conv2_1") + relu2_1 = mx.symbol.Activation(data=conv2_1, act_type="relu", name="relu2_1") + pool2 = mx.symbol.Pooling( + data=relu2_1, pool_type="max", kernel=(2, 2), stride=(2,2), name="pool2") + # group 3 + conv3_1 = mx.symbol.Convolution( + data=pool2, kernel=(3, 3), pad=(1, 1), num_filter=256, name="conv3_1") + relu3_1 = mx.symbol.Activation(data=conv3_1, act_type="relu", name="relu3_1") + conv3_2 = mx.symbol.Convolution( + data=relu3_1, kernel=(3, 3), pad=(1, 1), num_filter=256, name="conv3_2") + relu3_2 = mx.symbol.Activation(data=conv3_2, act_type="relu", name="relu3_2") + pool3 = mx.symbol.Pooling( + data=relu3_2, pool_type="max", kernel=(2, 2), stride=(2,2), name="pool3") + # group 4 + conv4_1 = mx.symbol.Convolution( + data=pool3, kernel=(3, 3), pad=(1, 1), num_filter=512, name="conv4_1") + relu4_1 = mx.symbol.Activation(data=conv4_1, act_type="relu", name="relu4_1") + conv4_2 = mx.symbol.Convolution( + data=relu4_1, kernel=(3, 3), pad=(1, 1), num_filter=512, name="conv4_2") + relu4_2 = mx.symbol.Activation(data=conv4_2, act_type="relu", name="relu4_2") + pool4 = mx.symbol.Pooling( + data=relu4_2, pool_type="max", kernel=(2, 2), stride=(2,2), name="pool4") + # group 5 + conv5_1 = mx.symbol.Convolution( + data=pool4, kernel=(3, 3), pad=(1, 1), num_filter=512, name="conv5_1") + relu5_1 = mx.symbol.Activation(data=conv5_1, act_type="relu", name="relu5_1") + conv5_2 = mx.symbol.Convolution( + data=relu5_1, kernel=(3, 3), pad=(1, 1), num_filter=512, name="conv5_2") + relu5_2 = mx.symbol.Activation(data=conv5_2, act_type="relu", name="conv1_2") + pool5 = mx.symbol.Pooling( + data=relu5_2, pool_type="max", kernel=(2, 2), stride=(2,2), name="pool5") + # group 6 + flatten = mx.symbol.Flatten(data=pool5, name="flatten") + fc6 = mx.symbol.FullyConnected(data=flatten, num_hidden=4096, name="fc6") + relu6 = mx.symbol.Activation(data=fc6, act_type="relu", name="relu6") + drop6 = mx.symbol.Dropout(data=relu6, p=0.5, name="drop6") + # group 7 + fc7 = mx.symbol.FullyConnected(data=drop6, num_hidden=4096, name="fc7") + relu7 = mx.symbol.Activation(data=fc7, act_type="relu", name="relu7") + drop7 = mx.symbol.Dropout(data=relu7, p=0.5, name="drop7") + # output + fc8 = mx.symbol.FullyConnected(data=drop7, num_hidden=num_classes, name="fc8") + softmax = mx.symbol.SoftmaxOutput(data=fc8, name='softmax') + return softmax From 60043e7f1f1f00b3941b19f61b0212e218aaef74 Mon Sep 17 00:00:00 2001 From: haijieg Date: Thu, 31 Mar 2016 15:07:25 -0700 Subject: [PATCH 5/8] Rename illegal module name --- .../{symbol_inception-bn.py => symbol_inception_bn.py} | 0 ...l_inception-bn-28-small.py => symbol_inception_bn_28_small.py} | 0 .../{symbol_inception-bn-full.py => symbol_inception_bn_full.py} | 0 .../{symbol_inception-v3.py => symbol_inception_v3.py} | 0 4 files changed, 0 insertions(+), 0 deletions(-) rename python/mxnet/builtin_symbols/{symbol_inception-bn.py => symbol_inception_bn.py} (100%) rename python/mxnet/builtin_symbols/{symbol_inception-bn-28-small.py => symbol_inception_bn_28_small.py} (100%) rename 
python/mxnet/builtin_symbols/{symbol_inception-bn-full.py => symbol_inception_bn_full.py} (100%) rename python/mxnet/builtin_symbols/{symbol_inception-v3.py => symbol_inception_v3.py} (100%) diff --git a/python/mxnet/builtin_symbols/symbol_inception-bn.py b/python/mxnet/builtin_symbols/symbol_inception_bn.py similarity index 100% rename from python/mxnet/builtin_symbols/symbol_inception-bn.py rename to python/mxnet/builtin_symbols/symbol_inception_bn.py diff --git a/python/mxnet/builtin_symbols/symbol_inception-bn-28-small.py b/python/mxnet/builtin_symbols/symbol_inception_bn_28_small.py similarity index 100% rename from python/mxnet/builtin_symbols/symbol_inception-bn-28-small.py rename to python/mxnet/builtin_symbols/symbol_inception_bn_28_small.py diff --git a/python/mxnet/builtin_symbols/symbol_inception-bn-full.py b/python/mxnet/builtin_symbols/symbol_inception_bn_full.py similarity index 100% rename from python/mxnet/builtin_symbols/symbol_inception-bn-full.py rename to python/mxnet/builtin_symbols/symbol_inception_bn_full.py diff --git a/python/mxnet/builtin_symbols/symbol_inception-v3.py b/python/mxnet/builtin_symbols/symbol_inception_v3.py similarity index 100% rename from python/mxnet/builtin_symbols/symbol_inception-v3.py rename to python/mxnet/builtin_symbols/symbol_inception_v3.py From e4f673543e11116a459010cebfcf8142149090a9 Mon Sep 17 00:00:00 2001 From: haijieg Date: Thu, 31 Mar 2016 19:50:23 -0700 Subject: [PATCH 6/8] Fix builtin symbol import --- python/mxnet/__init__.py | 1 + python/mxnet/builtin_symbols/__init__.py | 7 ++ .../mxnet/builtin_symbols/symbol_alexnet.py | 52 +++++++-------- .../mxnet/builtin_symbols/symbol_googlenet.py | 28 ++++---- .../builtin_symbols/symbol_inception_bn.py | 30 ++++----- .../symbol_inception_bn_28_small.py | 24 +++---- .../symbol_inception_bn_full.py | 30 ++++----- .../builtin_symbols/symbol_inception_v3.py | 42 ++++++------ python/mxnet/builtin_symbols/symbol_vgg.py | 64 +++++++++---------- python/mxnet/callback.py | 1 - python/mxnet/model.py | 32 ++++++++-- 11 files changed, 170 insertions(+), 141 deletions(-) diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py index 01d0438e25a0..3d76bb962145 100644 --- a/python/mxnet/__init__.py +++ b/python/mxnet/__init__.py @@ -43,5 +43,6 @@ from . import torch from . import torch as th +from . import builtin_symbols __version__ = base.__version__ diff --git a/python/mxnet/builtin_symbols/__init__.py b/python/mxnet/builtin_symbols/__init__.py index e69de29bb2d1..a109ee68998c 100644 --- a/python/mxnet/builtin_symbols/__init__.py +++ b/python/mxnet/builtin_symbols/__init__.py @@ -0,0 +1,7 @@ +from . import symbol_alexnet +from . import symbol_googlenet +from . import symbol_vgg +from . import symbol_inception_v3 +from . import symbol_inception_bn +from . import symbol_inception_bn_full +from . import symbol_inception_bn_28_small diff --git a/python/mxnet/builtin_symbols/symbol_alexnet.py b/python/mxnet/builtin_symbols/symbol_alexnet.py index d255b2fc0ce3..5d61ef961196 100644 --- a/python/mxnet/builtin_symbols/symbol_alexnet.py +++ b/python/mxnet/builtin_symbols/symbol_alexnet.py @@ -1,4 +1,4 @@ -from ... import mxnet as mx +from .. import symbol def get_symbol(num_classes = 1000): """ @@ -15,41 +15,41 @@ def get_symbol(num_classes = 1000): classification with deep convolutional neural networks." Advances in neural information processing systems. 2012. 
""" - input_data = mx.symbol.Variable(name="data") + input_data = symbol.Variable(name="data") # stage 1 - conv1 = mx.symbol.Convolution( + conv1 = symbol.Convolution( data=input_data, kernel=(11, 11), stride=(4, 4), num_filter=96) - relu1 = mx.symbol.Activation(data=conv1, act_type="relu") - pool1 = mx.symbol.Pooling( + relu1 = symbol.Activation(data=conv1, act_type="relu") + pool1 = symbol.Pooling( data=relu1, pool_type="max", kernel=(3, 3), stride=(2,2)) - lrn1 = mx.symbol.LRN(data=pool1, alpha=0.0001, beta=0.75, knorm=1, nsize=5) + lrn1 = symbol.LRN(data=pool1, alpha=0.0001, beta=0.75, knorm=1, nsize=5) # stage 2 - conv2 = mx.symbol.Convolution( + conv2 = symbol.Convolution( data=lrn1, kernel=(5, 5), pad=(2, 2), num_filter=256) - relu2 = mx.symbol.Activation(data=conv2, act_type="relu") - pool2 = mx.symbol.Pooling(data=relu2, kernel=(3, 3), stride=(2, 2), pool_type="max") - lrn2 = mx.symbol.LRN(data=pool2, alpha=0.0001, beta=0.75, knorm=1, nsize=5) + relu2 = symbol.Activation(data=conv2, act_type="relu") + pool2 = symbol.Pooling(data=relu2, kernel=(3, 3), stride=(2, 2), pool_type="max") + lrn2 = symbol.LRN(data=pool2, alpha=0.0001, beta=0.75, knorm=1, nsize=5) # stage 3 - conv3 = mx.symbol.Convolution( + conv3 = symbol.Convolution( data=lrn2, kernel=(3, 3), pad=(1, 1), num_filter=384) - relu3 = mx.symbol.Activation(data=conv3, act_type="relu") - conv4 = mx.symbol.Convolution( + relu3 = symbol.Activation(data=conv3, act_type="relu") + conv4 = symbol.Convolution( data=relu3, kernel=(3, 3), pad=(1, 1), num_filter=384) - relu4 = mx.symbol.Activation(data=conv4, act_type="relu") - conv5 = mx.symbol.Convolution( + relu4 = symbol.Activation(data=conv4, act_type="relu") + conv5 = symbol.Convolution( data=relu4, kernel=(3, 3), pad=(1, 1), num_filter=256) - relu5 = mx.symbol.Activation(data=conv5, act_type="relu") - pool3 = mx.symbol.Pooling(data=relu5, kernel=(3, 3), stride=(2, 2), pool_type="max") + relu5 = symbol.Activation(data=conv5, act_type="relu") + pool3 = symbol.Pooling(data=relu5, kernel=(3, 3), stride=(2, 2), pool_type="max") # stage 4 - flatten = mx.symbol.Flatten(data=pool3) - fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=4096) - relu6 = mx.symbol.Activation(data=fc1, act_type="relu") - dropout1 = mx.symbol.Dropout(data=relu6, p=0.5) + flatten = symbol.Flatten(data=pool3) + fc1 = symbol.FullyConnected(data=flatten, num_hidden=4096) + relu6 = symbol.Activation(data=fc1, act_type="relu") + dropout1 = symbol.Dropout(data=relu6, p=0.5) # stage 5 - fc2 = mx.symbol.FullyConnected(data=dropout1, num_hidden=4096) - relu7 = mx.symbol.Activation(data=fc2, act_type="relu") - dropout2 = mx.symbol.Dropout(data=relu7, p=0.5) + fc2 = symbol.FullyConnected(data=dropout1, num_hidden=4096) + relu7 = symbol.Activation(data=fc2, act_type="relu") + dropout2 = symbol.Dropout(data=relu7, p=0.5) # stage 6 - fc3 = mx.symbol.FullyConnected(data=dropout2, num_hidden=num_classes) - softmax = mx.symbol.SoftmaxOutput(data=fc3, name='softmax') + fc3 = symbol.FullyConnected(data=dropout2, num_hidden=num_classes) + softmax = symbol.SoftmaxOutput(data=fc3, name='softmax') return softmax diff --git a/python/mxnet/builtin_symbols/symbol_googlenet.py b/python/mxnet/builtin_symbols/symbol_googlenet.py index 4fcf3fa5f223..dab6c7e51501 100644 --- a/python/mxnet/builtin_symbols/symbol_googlenet.py +++ b/python/mxnet/builtin_symbols/symbol_googlenet.py @@ -1,8 +1,8 @@ -from ... import mxnet as mx +from .. 
import symbol def ConvFactory(data, num_filter, kernel, stride=(1,1), pad=(0, 0), name=None, suffix=''): - conv = mx.symbol.Convolution(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad, name='conv_%s%s' %(name, suffix)) - act = mx.symbol.Activation(data=conv, act_type='relu', name='relu_%s%s' %(name, suffix)) + conv = symbol.Convolution(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad, name='conv_%s%s' %(name, suffix)) + act = symbol.Activation(data=conv, act_type='relu', name='relu_%s%s' %(name, suffix)) return act def InceptionFactory(data, num_1x1, num_3x3red, num_3x3, num_d5x5red, num_d5x5, pool, proj, name): @@ -15,10 +15,10 @@ def InceptionFactory(data, num_1x1, num_3x3red, num_3x3, num_d5x5red, num_d5x5, cd5x5r = ConvFactory(data=data, num_filter=num_d5x5red, kernel=(1, 1), name=('%s_5x5' % name), suffix='_reduce') cd5x5 = ConvFactory(data=cd5x5r, num_filter=num_d5x5, kernel=(5, 5), pad=(2, 2), name=('%s_5x5' % name)) # pool + proj - pooling = mx.symbol.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name))) + pooling = symbol.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name))) cproj = ConvFactory(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_proj' % name)) # concat - concat = mx.symbol.Concat(*[c1x1, c3x3, cd5x5, cproj], name='ch_concat_%s_chconcat' % name) + concat = symbol.Concat(*[c1x1, c3x3, cd5x5, cproj], name='ch_concat_%s_chconcat' % name) return concat def get_symbol(num_classes = 1000): @@ -36,26 +36,26 @@ def get_symbol(num_classes = 1000): Anguelov, Dumitru Erhan, Vincent Vanhoucke, and Andrew Rabinovich. "Going deeper with convolutions." arXiv preprint arXiv:1409.4842, 2014. 
""" - data = mx.sym.Variable("data") + data = sym.Variable("data") conv1 = ConvFactory(data, 64, kernel=(7, 7), stride=(2,2), pad=(3, 3), name="conv1") - pool1 = mx.sym.Pooling(conv1, kernel=(3, 3), stride=(2, 2), pool_type="max") + pool1 = sym.Pooling(conv1, kernel=(3, 3), stride=(2, 2), pool_type="max") conv2 = ConvFactory(pool1, 64, kernel=(1, 1), stride=(1,1), name="conv2") conv3 = ConvFactory(conv2, 192, kernel=(3, 3), stride=(1, 1), pad=(1,1), name="conv3") - pool3 = mx.sym.Pooling(conv3, kernel=(3, 3), stride=(2, 2), pool_type="max") + pool3 = sym.Pooling(conv3, kernel=(3, 3), stride=(2, 2), pool_type="max") in3a = InceptionFactory(pool3, 64, 96, 128, 16, 32, "max", 32, name="in3a") in3b = InceptionFactory(in3a, 128, 128, 192, 32, 96, "max", 64, name="in3b") - pool4 = mx.sym.Pooling(in3b, kernel=(3, 3), stride=(2, 2), pool_type="max") + pool4 = sym.Pooling(in3b, kernel=(3, 3), stride=(2, 2), pool_type="max") in4a = InceptionFactory(pool4, 192, 96, 208, 16, 48, "max", 64, name="in4a") in4b = InceptionFactory(in4a, 160, 112, 224, 24, 64, "max", 64, name="in4b") in4c = InceptionFactory(in4b, 128, 128, 256, 24, 64, "max", 64, name="in4c") in4d = InceptionFactory(in4c, 112, 144, 288, 32, 64, "max", 64, name="in4d") in4e = InceptionFactory(in4d, 256, 160, 320, 32, 128, "max", 128, name="in4e") - pool5 = mx.sym.Pooling(in4e, kernel=(3, 3), stride=(2, 2), pool_type="max") + pool5 = sym.Pooling(in4e, kernel=(3, 3), stride=(2, 2), pool_type="max") in5a = InceptionFactory(pool5, 256, 160, 320, 32, 128, "max", 128, name="in5a") in5b = InceptionFactory(in5a, 384, 192, 384, 48, 128, "max", 128, name="in5b") - pool6 = mx.sym.Pooling(in5b, kernel=(7, 7), stride=(1,1), pool_type="avg") - flatten = mx.sym.Flatten(data=pool6) - fc1 = mx.sym.FullyConnected(data=flatten, num_hidden=num_classes) - softmax = mx.symbol.SoftmaxOutput(data=fc1, name='softmax') + pool6 = sym.Pooling(in5b, kernel=(7, 7), stride=(1,1), pool_type="avg") + flatten = sym.Flatten(data=pool6) + fc1 = sym.FullyConnected(data=flatten, num_hidden=num_classes) + softmax = symbol.SoftmaxOutput(data=fc1, name='softmax') return softmax diff --git a/python/mxnet/builtin_symbols/symbol_inception_bn.py b/python/mxnet/builtin_symbols/symbol_inception_bn.py index 81eb1a97b7be..b96a73561c14 100644 --- a/python/mxnet/builtin_symbols/symbol_inception_bn.py +++ b/python/mxnet/builtin_symbols/symbol_inception_bn.py @@ -1,9 +1,9 @@ -from ... import mxnet as mx +from .. 
import symbol def ConvFactory(data, num_filter, kernel, stride=(1,1), pad=(0, 0), name=None, suffix=''): - conv = mx.symbol.Convolution(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad, name='conv_%s%s' %(name, suffix)) - bn = mx.symbol.BatchNorm(data=conv, name='bn_%s%s' %(name, suffix)) - act = mx.symbol.Activation(data=bn, act_type='relu', name='relu_%s%s' %(name, suffix)) + conv = symbol.Convolution(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad, name='conv_%s%s' %(name, suffix)) + bn = symbol.BatchNorm(data=conv, name='bn_%s%s' %(name, suffix)) + act = symbol.Activation(data=bn, act_type='relu', name='relu_%s%s' %(name, suffix)) return act def InceptionFactoryA(data, num_1x1, num_3x3red, num_3x3, num_d3x3red, num_d3x3, pool, proj, name): @@ -17,10 +17,10 @@ def InceptionFactoryA(data, num_1x1, num_3x3red, num_3x3, num_d3x3red, num_d3x3, cd3x3 = ConvFactory(data=cd3x3r, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), name=('%s_double_3x3_0' % name)) cd3x3 = ConvFactory(data=cd3x3, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), name=('%s_double_3x3_1' % name)) # pool + proj - pooling = mx.symbol.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name))) + pooling = symbol.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name))) cproj = ConvFactory(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_proj' % name)) # concat - concat = mx.symbol.Concat(*[c1x1, c3x3, cd3x3, cproj], name='ch_concat_%s_chconcat' % name) + concat = symbol.Concat(*[c1x1, c3x3, cd3x3, cproj], name='ch_concat_%s_chconcat' % name) return concat def InceptionFactoryB(data, num_3x3red, num_3x3, num_d3x3red, num_d3x3, name): @@ -32,9 +32,9 @@ def InceptionFactoryB(data, num_3x3red, num_3x3, num_d3x3red, num_d3x3, name): cd3x3 = ConvFactory(data=cd3x3r, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name=('%s_double_3x3_0' % name)) cd3x3 = ConvFactory(data=cd3x3, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name=('%s_double_3x3_1' % name)) # pool + proj - pooling = mx.symbol.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pool_type="max", name=('max_pool_%s_pool' % name)) + pooling = symbol.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pool_type="max", name=('max_pool_%s_pool' % name)) # concat - concat = mx.symbol.Concat(*[c3x3, cd3x3, pooling], name='ch_concat_%s_chconcat' % name) + concat = symbol.Concat(*[c3x3, cd3x3, pooling], name='ch_concat_%s_chconcat' % name) return concat def get_symbol(num_classes=1000): @@ -56,14 +56,14 @@ def get_symbol(num_classes=1000): """ # data - data = mx.symbol.Variable(name="data") + data = symbol.Variable(name="data") # stage 1 conv1 = ConvFactory(data=data, num_filter=64, kernel=(7, 7), stride=(2, 2), pad=(3, 3), name='conv1') - pool1 = mx.symbol.Pooling(data=conv1, kernel=(3, 3), stride=(2, 2), name='pool1', pool_type='max') + pool1 = symbol.Pooling(data=conv1, kernel=(3, 3), stride=(2, 2), name='pool1', pool_type='max') # stage 2 conv2red = ConvFactory(data=pool1, num_filter=64, kernel=(1, 1), stride=(1, 1), name='conv2red') conv2 = ConvFactory(data=conv2red, num_filter=192, kernel=(3, 3), stride=(1, 1), pad=(1, 1), name='conv2') - pool2 = mx.symbol.Pooling(data=conv2, kernel=(3, 3), stride=(2, 2), name='pool2', pool_type='max') + pool2 = symbol.Pooling(data=conv2, kernel=(3, 3), stride=(2, 2), name='pool2', pool_type='max') # stage 2 in3a = InceptionFactoryA(pool2, 64, 
     in3b = InceptionFactoryA(in3a, 64, 64, 96, 64, 96, "avg", 64, '3b')
@@ -78,9 +78,9 @@ def get_symbol(num_classes=1000):
     in5a = InceptionFactoryA(in4e, 352, 192, 320, 160, 224, "avg", 128, '5a')
     in5b = InceptionFactoryA(in5a, 352, 192, 320, 192, 224, "max", 128, '5b')
     # global avg pooling
-    avg = mx.symbol.Pooling(data=in5b, kernel=(7, 7), stride=(1, 1), name="global_pool", pool_type='avg')
+    avg = symbol.Pooling(data=in5b, kernel=(7, 7), stride=(1, 1), name="global_pool", pool_type='avg')
     # linear classifier
-    flatten = mx.symbol.Flatten(data=avg, name='flatten')
-    fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=num_classes, name='fc1')
-    softmax = mx.symbol.SoftmaxOutput(data=fc1, name='softmax')
+    flatten = symbol.Flatten(data=avg, name='flatten')
+    fc1 = symbol.FullyConnected(data=flatten, num_hidden=num_classes, name='fc1')
+    softmax = symbol.SoftmaxOutput(data=fc1, name='softmax')
     return softmax

diff --git a/python/mxnet/builtin_symbols/symbol_inception_bn_28_small.py b/python/mxnet/builtin_symbols/symbol_inception_bn_28_small.py
index d872559a8f30..aa19ca9ec175 100644
--- a/python/mxnet/builtin_symbols/symbol_inception_bn_28_small.py
+++ b/python/mxnet/builtin_symbols/symbol_inception_bn_28_small.py
@@ -1,10 +1,10 @@
-from ... import mxnet as mx
+from .. import symbol

 # Basic Conv + BN + ReLU factory
 def ConvFactory(data, num_filter, kernel, stride=(1,1), pad=(0, 0), act_type="relu"):
-    conv = mx.symbol.Convolution(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad)
-    bn = mx.symbol.BatchNorm(data=conv)
-    act = mx.symbol.Activation(data = bn, act_type=act_type)
+    conv = symbol.Convolution(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad)
+    bn = symbol.BatchNorm(data=conv)
+    act = symbol.Activation(data = bn, act_type=act_type)
     return act

 # A Simple Downsampling Factory
@@ -12,9 +12,9 @@ def DownsampleFactory(data, ch_3x3):
     # conv 3x3
     conv = ConvFactory(data=data, kernel=(3, 3), stride=(2, 2), num_filter=ch_3x3, pad=(1, 1))
     # pool
-    pool = mx.symbol.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pool_type='max')
+    pool = symbol.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pool_type='max')
     # concat
-    concat = mx.symbol.Concat(*[conv, pool])
+    concat = symbol.Concat(*[conv, pool])
     return concat

 # A Simple module
@@ -24,7 +24,7 @@ def SimpleFactory(data, ch_1x1, ch_3x3):
     # 3x3
     conv3x3 = ConvFactory(data=data, kernel=(3, 3), pad=(1, 1), num_filter=ch_3x3)
     # concat
-    concat = mx.symbol.Concat(*[conv1x1, conv3x3])
+    concat = symbol.Concat(*[conv1x1, conv3x3])
     return concat

 def get_symbol(num_classes = 10):
@@ -44,7 +44,7 @@ def get_symbol(num_classes = 10):
     network training by reducing internal covariate shift. arXiv preprint
     arXiv:1502.03167, 2015.
""" - data = mx.symbol.Variable(name="data") + data = symbol.Variable(name="data") conv1 = ConvFactory(data=data, kernel=(3,3), pad=(1,1), num_filter=96, act_type="relu") in3a = SimpleFactory(conv1, 32, 32) in3b = SimpleFactory(in3a, 32, 48) @@ -56,8 +56,8 @@ def get_symbol(num_classes = 10): in4e = DownsampleFactory(in4d, 96) in5a = SimpleFactory(in4e, 176, 160) in5b = SimpleFactory(in5a, 176, 160) - pool = mx.symbol.Pooling(data=in5b, pool_type="avg", kernel=(7,7), name="global_pool") - flatten = mx.symbol.Flatten(data=pool, name="flatten1") - fc = mx.symbol.FullyConnected(data=flatten, num_hidden=num_classes, name="fc1") - softmax = mx.symbol.SoftmaxOutput(data=fc, name="softmax") + pool = symbol.Pooling(data=in5b, pool_type="avg", kernel=(7,7), name="global_pool") + flatten = symbol.Flatten(data=pool, name="flatten1") + fc = symbol.FullyConnected(data=flatten, num_hidden=num_classes, name="fc1") + softmax = symbol.SoftmaxOutput(data=fc, name="softmax") return softmax diff --git a/python/mxnet/builtin_symbols/symbol_inception_bn_full.py b/python/mxnet/builtin_symbols/symbol_inception_bn_full.py index c8dbf56a3d93..44b773eb838c 100644 --- a/python/mxnet/builtin_symbols/symbol_inception_bn_full.py +++ b/python/mxnet/builtin_symbols/symbol_inception_bn_full.py @@ -1,9 +1,9 @@ -from ... import mxnet as mx +from .. import symbol def ConvFactory(data, num_filter, kernel, stride=(1,1), pad=(0, 0), name=None, suffix=''): - conv = mx.symbol.Convolution(data=data, workspace=512, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad, name='conv_%s%s' %(name, suffix)) - bn = mx.symbol.BatchNorm(data=conv, name='bn_%s%s' %(name, suffix)) - act = mx.symbol.Activation(data=bn, act_type='relu', name='relu_%s%s' %(name, suffix)) + conv = symbol.Convolution(data=data, workspace=512, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad, name='conv_%s%s' %(name, suffix)) + bn = symbol.BatchNorm(data=conv, name='bn_%s%s' %(name, suffix)) + act = symbol.Activation(data=bn, act_type='relu', name='relu_%s%s' %(name, suffix)) return act def InceptionFactoryA(data, num_1x1, num_3x3red, num_3x3, num_d3x3red, num_d3x3, pool, proj, name): @@ -17,10 +17,10 @@ def InceptionFactoryA(data, num_1x1, num_3x3red, num_3x3, num_d3x3red, num_d3x3, cd3x3 = ConvFactory(data=cd3x3r, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), name=('%s_double_3x3_0' % name)) cd3x3 = ConvFactory(data=cd3x3, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), name=('%s_double_3x3_1' % name)) # pool + proj - pooling = mx.symbol.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name))) + pooling = symbol.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name))) cproj = ConvFactory(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_proj' % name)) # concat - concat = mx.symbol.Concat(*[c1x1, c3x3, cd3x3, cproj], name='ch_concat_%s_chconcat' % name) + concat = symbol.Concat(*[c1x1, c3x3, cd3x3, cproj], name='ch_concat_%s_chconcat' % name) return concat def InceptionFactoryB(data, num_3x3red, num_3x3, num_d3x3red, num_d3x3, name): @@ -32,9 +32,9 @@ def InceptionFactoryB(data, num_3x3red, num_3x3, num_d3x3red, num_d3x3, name): cd3x3 = ConvFactory(data=cd3x3r, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name=('%s_double_3x3_0' % name)) cd3x3 = ConvFactory(data=cd3x3, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name=('%s_double_3x3_1' % name)) # pool + proj - pooling = 
+    pooling = symbol.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pool_type="max", name=('max_pool_%s_pool' % name))
     # concat
-    concat = mx.symbol.Concat(*[c3x3, cd3x3, pooling], name='ch_concat_%s_chconcat' % name)
+    concat = symbol.Concat(*[c3x3, cd3x3, pooling], name='ch_concat_%s_chconcat' % name)
     return concat

 def get_symbol(num_classes = 21841):
@@ -56,14 +56,14 @@ def get_symbol(num_classes = 21841):
     """
     # data
-    data = mx.symbol.Variable(name="data")
+    data = symbol.Variable(name="data")
     # stage 1
     conv1 = ConvFactory(data=data, num_filter=96, kernel=(7, 7), stride=(2, 2), pad=(3, 3), name='conv1')
-    pool1 = mx.symbol.Pooling(data=conv1, kernel=(3, 3), stride=(2, 2), name='pool1', pool_type='max')
+    pool1 = symbol.Pooling(data=conv1, kernel=(3, 3), stride=(2, 2), name='pool1', pool_type='max')
     # stage 2
     conv2red = ConvFactory(data=pool1, num_filter=128, kernel=(1, 1), stride=(1, 1), name='conv2red')
     conv2 = ConvFactory(data=conv2red, num_filter=288, kernel=(3, 3), stride=(1, 1), pad=(1, 1), name='conv2')
-    pool2 = mx.symbol.Pooling(data=conv2, kernel=(3, 3), stride=(2, 2), name='pool2', pool_type='max')
+    pool2 = symbol.Pooling(data=conv2, kernel=(3, 3), stride=(2, 2), name='pool2', pool_type='max')
     # stage 3
     in3a = InceptionFactoryA(pool2, 96, 96, 96, 96, 144, "avg", 48, '3a')
     in3b = InceptionFactoryA(in3a, 96, 96, 144, 96, 144, "avg", 96, '3b')
@@ -78,9 +78,9 @@ def get_symbol(num_classes = 21841):
     in5a = InceptionFactoryA(in4e, 352, 192, 320, 160, 224, "avg", 128, '5a')
     in5b = InceptionFactoryA(in5a, 352, 192, 320, 192, 224, "max", 128, '5b')
     # global avg pooling
-    avg = mx.symbol.Pooling(data=in5b, kernel=(7, 7), stride=(1, 1), name="global_pool", pool_type='avg')
+    avg = symbol.Pooling(data=in5b, kernel=(7, 7), stride=(1, 1), name="global_pool", pool_type='avg')
     # linear classifier
-    flatten = mx.symbol.Flatten(data=avg, name='flatten')
-    fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=num_classes, name='fc1')
-    softmax = mx.symbol.SoftmaxOutput(data=fc1, name='softmax')
+    flatten = symbol.Flatten(data=avg, name='flatten')
+    fc1 = symbol.FullyConnected(data=flatten, num_hidden=num_classes, name='fc1')
+    softmax = symbol.SoftmaxOutput(data=fc1, name='softmax')
     return softmax

diff --git a/python/mxnet/builtin_symbols/symbol_inception_v3.py b/python/mxnet/builtin_symbols/symbol_inception_v3.py
index b5c744164604..ee4c763a5f19 100644
--- a/python/mxnet/builtin_symbols/symbol_inception_v3.py
+++ b/python/mxnet/builtin_symbols/symbol_inception_v3.py
@@ -1,9 +1,9 @@
-from ... import mxnet as mx
+from .. import symbol as sym

 def Conv(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), name=None, suffix=''):
-    conv = mx.sym.Convolution(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad, no_bias=True, name='%s%s_conv2d' %(name, suffix))
-    bn = mx.sym.BatchNorm(data=conv, name='%s%s_batchnorm' %(name, suffix), fix_gamma=True)
-    act = mx.sym.Activation(data=bn, act_type='relu', name='%s%s_relu' %(name, suffix))
+    conv = sym.Convolution(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad, no_bias=True, name='%s%s_conv2d' %(name, suffix))
+    bn = sym.BatchNorm(data=conv, name='%s%s_batchnorm' %(name, suffix), fix_gamma=True)
+    act = sym.Activation(data=bn, act_type='relu', name='%s%s_relu' %(name, suffix))
     return act

@@ -19,9 +19,9 @@ def Inception7A(data,
     tower_3x3 = Conv(data, num_3x3_red, name=('%s_tower_1' % name), suffix='_conv')
     tower_3x3 = Conv(tower_3x3, num_3x3_1, kernel=(3, 3), pad=(1, 1), name=('%s_tower_1' % name), suffix='_conv_1')
     tower_3x3 = Conv(tower_3x3, num_3x3_2, kernel=(3, 3), pad=(1, 1), name=('%s_tower_1' % name), suffix='_conv_2')
-    pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
+    pooling = sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
     cproj = Conv(pooling, proj, name=('%s_tower_2' % name), suffix='_conv')
-    concat = mx.sym.Concat(*[tower_1x1, tower_5x5, tower_3x3, cproj], name='ch_concat_%s_chconcat' % name)
+    concat = sym.Concat(*[tower_1x1, tower_5x5, tower_3x3, cproj], name='ch_concat_%s_chconcat' % name)
     return concat

 # First Downsample
@@ -34,8 +34,8 @@ def Inception7B(data,
     tower_d3x3 = Conv(data, num_d3x3_red, name=('%s_tower' % name), suffix='_conv')
     tower_d3x3 = Conv(tower_d3x3, num_d3x3_1, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name=('%s_tower' % name), suffix='_conv_1')
     tower_d3x3 = Conv(tower_d3x3, num_d3x3_2, kernel=(3, 3), pad=(0, 0), stride=(2, 2), name=('%s_tower' % name), suffix='_conv_2')
-    pooling = mx.symbol.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pad=(0,0), pool_type="max", name=('max_pool_%s_pool' % name))
-    concat = mx.sym.Concat(*[tower_3x3, tower_d3x3, pooling], name='ch_concat_%s_chconcat' % name)
+    pooling = sym.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pad=(0,0), pool_type="max", name=('max_pool_%s_pool' % name))
+    concat = sym.Concat(*[tower_3x3, tower_d3x3, pooling], name='ch_concat_%s_chconcat' % name)
     return concat

 def Inception7C(data,
@@ -53,10 +53,10 @@ def Inception7C(data,
     tower_q7 = Conv(data=tower_q7, num_filter=num_q7_2, kernel=(1, 7), pad=(0, 3), name=('%s_tower_1' % name), suffix='_conv_2')
     tower_q7 = Conv(data=tower_q7, num_filter=num_q7_3, kernel=(7, 1), pad=(3, 0), name=('%s_tower_1' % name), suffix='_conv_3')
     tower_q7 = Conv(data=tower_q7, num_filter=num_q7_4, kernel=(1, 7), pad=(0, 3), name=('%s_tower_1' % name), suffix='_conv_4')
-    pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
+    pooling = sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
     cproj = Conv(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_tower_2' % name), suffix='_conv')
     # concat
-    concat = mx.sym.Concat(*[tower_1x1, tower_d7, tower_q7, cproj], name='ch_concat_%s_chconcat' % name)
+    concat = sym.Concat(*[tower_1x1, tower_d7, tower_q7, cproj], name='ch_concat_%s_chconcat' % name)
     return concat

 def Inception7D(data,
@@ -70,9 +70,9 @@ def Inception7D(data,
     tower_d7_3x3 = Conv(data=tower_d7_3x3, num_filter=num_d7_1, kernel=(1, 7), pad=(0, 3), name=('%s_tower_1' % name), suffix='_conv_1')
     tower_d7_3x3 = Conv(data=tower_d7_3x3, num_filter=num_d7_2, kernel=(7, 1), pad=(3, 0), name=('%s_tower_1' % name), suffix='_conv_2')
     tower_d7_3x3 = Conv(data=tower_d7_3x3, num_filter=num_d7_3x3, kernel=(3, 3), stride=(2, 2), name=('%s_tower_1' % name), suffix='_conv_3')
-    pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
+    pooling = sym.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
     # concat
-    concat = mx.sym.Concat(*[tower_3x3, tower_d7_3x3, pooling], name='ch_concat_%s_chconcat' % name)
+    concat = sym.Concat(*[tower_3x3, tower_d7_3x3, pooling], name='ch_concat_%s_chconcat' % name)
     return concat

 def Inception7E(data,
@@ -89,10 +89,10 @@ def Inception7E(data,
     tower_3x3_d3 = Conv(data=tower_3x3_d3, num_filter=num_3x3, kernel=(3, 3), pad=(1, 1), name=('%s_tower_1' % name), suffix='_conv_1')
     tower_3x3_d3_a = Conv(data=tower_3x3_d3, num_filter=num_3x3_d3_1, kernel=(1, 3), pad=(0, 1), name=('%s_tower_1' % name), suffix='_mixed_conv')
     tower_3x3_d3_b = Conv(data=tower_3x3_d3, num_filter=num_3x3_d3_2, kernel=(3, 1), pad=(1, 0), name=('%s_tower_1' % name), suffix='_mixed_conv_1')
-    pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
+    pooling = sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
     cproj = Conv(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_tower_2' % name), suffix='_conv')
     # concat
-    concat = mx.sym.Concat(*[tower_1x1, tower_d3_a, tower_d3_b, tower_3x3_d3_a, tower_3x3_d3_b, cproj], name='ch_concat_%s_chconcat' % name)
+    concat = sym.Concat(*[tower_1x1, tower_d3_a, tower_d3_b, tower_3x3_d3_a, tower_3x3_d3_b, cproj], name='ch_concat_%s_chconcat' % name)
     return concat

 # In[49]:

@@ -112,16 +112,16 @@ def get_symbol(num_classes=1000):
     Rethinking the Inception Architecture for Computer Vision
     arXiv preprint arXiv:1512.00567, 2015.
""" - data = mx.symbol.Variable(name="data") + data = sym.Variable(name="data") # stage 1 conv = Conv(data, 32, kernel=(3, 3), stride=(2, 2), name="conv") conv_1 = Conv(conv, 32, kernel=(3, 3), name="conv_1") conv_2 = Conv(conv_1, 64, kernel=(3, 3), pad=(1, 1), name="conv_2") - pool = mx.sym.Pooling(data=conv_2, kernel=(3, 3), stride=(2, 2), pool_type="max", name="pool") + pool = sym.Pooling(data=conv_2, kernel=(3, 3), stride=(2, 2), pool_type="max", name="pool") # stage 2 conv_3 = Conv(pool, 80, kernel=(1, 1), name="conv_3") conv_4 = Conv(conv_3, 192, kernel=(3, 3), name="conv_4") - pool1 = mx.sym.Pooling(data=conv_4, kernel=(3, 3), stride=(2, 2), pool_type="max", name="pool1") + pool1 = sym.Pooling(data=conv_4, kernel=(3, 3), stride=(2, 2), pool_type="max", name="pool1") # stage 3 in3a = Inception7A(pool1, 64, 64, 96, 96, @@ -168,9 +168,9 @@ def get_symbol(num_classes=1000): 448, 384, 384, 384, "max", 192, "mixed_10") # pool - pool = mx.sym.Pooling(data=in5b, kernel=(8, 8), stride=(1, 1), pool_type="avg", name="global_pool") - flatten = mx.sym.Flatten(data=pool, name="flatten") - fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=num_classes, name='fc1') - softmax = mx.symbol.SoftmaxOutput(data=fc1, name='softmax') + pool = sym.Pooling(data=in5b, kernel=(8, 8), stride=(1, 1), pool_type="avg", name="global_pool") + flatten = sym.Flatten(data=pool, name="flatten") + fc1 = sym.FullyConnected(data=flatten, num_hidden=num_classes, name='fc1') + softmax = sym.SoftmaxOutput(data=fc1, name='softmax') return softmax diff --git a/python/mxnet/builtin_symbols/symbol_vgg.py b/python/mxnet/builtin_symbols/symbol_vgg.py index b2de48ef9435..e46be6bdb478 100644 --- a/python/mxnet/builtin_symbols/symbol_vgg.py +++ b/python/mxnet/builtin_symbols/symbol_vgg.py @@ -1,4 +1,4 @@ -from ... import mxnet as mx +from .. 

 def get_symbol(num_classes = 1000):
     """
@@ -16,55 +16,55 @@ def get_symbol(num_classes = 1000):
     """
     ## define vgg
-    data = mx.symbol.Variable(name="data")
+    data = symbol.Variable(name="data")
     # group 1
-    conv1_1 = mx.symbol.Convolution(data=data, kernel=(3, 3), pad=(1, 1), num_filter=64, name="conv1_1")
-    relu1_1 = mx.symbol.Activation(data=conv1_1, act_type="relu", name="relu1_1")
-    pool1 = mx.symbol.Pooling(
+    conv1_1 = symbol.Convolution(data=data, kernel=(3, 3), pad=(1, 1), num_filter=64, name="conv1_1")
+    relu1_1 = symbol.Activation(data=conv1_1, act_type="relu", name="relu1_1")
+    pool1 = symbol.Pooling(
         data=relu1_1, pool_type="max", kernel=(2, 2), stride=(2,2), name="pool1")
     # group 2
-    conv2_1 = mx.symbol.Convolution(
+    conv2_1 = symbol.Convolution(
         data=pool1, kernel=(3, 3), pad=(1, 1), num_filter=128, name="conv2_1")
-    relu2_1 = mx.symbol.Activation(data=conv2_1, act_type="relu", name="relu2_1")
-    pool2 = mx.symbol.Pooling(
+    relu2_1 = symbol.Activation(data=conv2_1, act_type="relu", name="relu2_1")
+    pool2 = symbol.Pooling(
         data=relu2_1, pool_type="max", kernel=(2, 2), stride=(2,2), name="pool2")
     # group 3
-    conv3_1 = mx.symbol.Convolution(
+    conv3_1 = symbol.Convolution(
         data=pool2, kernel=(3, 3), pad=(1, 1), num_filter=256, name="conv3_1")
-    relu3_1 = mx.symbol.Activation(data=conv3_1, act_type="relu", name="relu3_1")
-    conv3_2 = mx.symbol.Convolution(
+    relu3_1 = symbol.Activation(data=conv3_1, act_type="relu", name="relu3_1")
+    conv3_2 = symbol.Convolution(
         data=relu3_1, kernel=(3, 3), pad=(1, 1), num_filter=256, name="conv3_2")
-    relu3_2 = mx.symbol.Activation(data=conv3_2, act_type="relu", name="relu3_2")
-    pool3 = mx.symbol.Pooling(
+    relu3_2 = symbol.Activation(data=conv3_2, act_type="relu", name="relu3_2")
+    pool3 = symbol.Pooling(
         data=relu3_2, pool_type="max", kernel=(2, 2), stride=(2,2), name="pool3")
     # group 4
-    conv4_1 = mx.symbol.Convolution(
+    conv4_1 = symbol.Convolution(
         data=pool3, kernel=(3, 3), pad=(1, 1), num_filter=512, name="conv4_1")
-    relu4_1 = mx.symbol.Activation(data=conv4_1, act_type="relu", name="relu4_1")
-    conv4_2 = mx.symbol.Convolution(
+    relu4_1 = symbol.Activation(data=conv4_1, act_type="relu", name="relu4_1")
+    conv4_2 = symbol.Convolution(
         data=relu4_1, kernel=(3, 3), pad=(1, 1), num_filter=512, name="conv4_2")
-    relu4_2 = mx.symbol.Activation(data=conv4_2, act_type="relu", name="relu4_2")
-    pool4 = mx.symbol.Pooling(
+    relu4_2 = symbol.Activation(data=conv4_2, act_type="relu", name="relu4_2")
+    pool4 = symbol.Pooling(
         data=relu4_2, pool_type="max", kernel=(2, 2), stride=(2,2), name="pool4")
     # group 5
-    conv5_1 = mx.symbol.Convolution(
+    conv5_1 = symbol.Convolution(
        data=pool4, kernel=(3, 3), pad=(1, 1), num_filter=512, name="conv5_1")
-    relu5_1 = mx.symbol.Activation(data=conv5_1, act_type="relu", name="relu5_1")
-    conv5_2 = mx.symbol.Convolution(
+    relu5_1 = symbol.Activation(data=conv5_1, act_type="relu", name="relu5_1")
+    conv5_2 = symbol.Convolution(
         data=relu5_1, kernel=(3, 3), pad=(1, 1), num_filter=512, name="conv5_2")
-    relu5_2 = mx.symbol.Activation(data=conv5_2, act_type="relu", name="conv1_2")
-    pool5 = mx.symbol.Pooling(
+    relu5_2 = symbol.Activation(data=conv5_2, act_type="relu", name="relu5_2")
+    pool5 = symbol.Pooling(
         data=relu5_2, pool_type="max", kernel=(2, 2), stride=(2,2), name="pool5")
     # group 6
-    flatten = mx.symbol.Flatten(data=pool5, name="flatten")
-    fc6 = mx.symbol.FullyConnected(data=flatten, num_hidden=4096, name="fc6")
-    relu6 = mx.symbol.Activation(data=fc6, act_type="relu", name="relu6")
-    drop6 = mx.symbol.Dropout(data=relu6, p=0.5, name="drop6")
+    flatten = symbol.Flatten(data=pool5, name="flatten")
+    fc6 = symbol.FullyConnected(data=flatten, num_hidden=4096, name="fc6")
+    relu6 = symbol.Activation(data=fc6, act_type="relu", name="relu6")
+    drop6 = symbol.Dropout(data=relu6, p=0.5, name="drop6")
     # group 7
-    fc7 = mx.symbol.FullyConnected(data=drop6, num_hidden=4096, name="fc7")
-    relu7 = mx.symbol.Activation(data=fc7, act_type="relu", name="relu7")
-    drop7 = mx.symbol.Dropout(data=relu7, p=0.5, name="drop7")
+    fc7 = symbol.FullyConnected(data=drop6, num_hidden=4096, name="fc7")
+    relu7 = symbol.Activation(data=fc7, act_type="relu", name="relu7")
+    drop7 = symbol.Dropout(data=relu7, p=0.5, name="drop7")
     # output
-    fc8 = mx.symbol.FullyConnected(data=drop7, num_hidden=num_classes, name="fc8")
-    softmax = mx.symbol.SoftmaxOutput(data=fc8, name='softmax')
+    fc8 = symbol.FullyConnected(data=drop7, num_hidden=num_classes, name="fc8")
+    softmax = symbol.SoftmaxOutput(data=fc8, name='softmax')
     return softmax

diff --git a/python/mxnet/callback.py b/python/mxnet/callback.py
index 18d1a7d4355f..9c3016daaf5c 100644
--- a/python/mxnet/callback.py
+++ b/python/mxnet/callback.py
@@ -10,7 +10,6 @@

 __LOGGER__ = logging.getLogger(__name__)
 __LOGGER__.setLevel(logging.INFO)
-__LOGGER__.info('callback logger activated')

 def do_checkpoint(prefix):
     """Callback to checkpoint the model to prefix every epoch.

diff --git a/python/mxnet/model.py b/python/mxnet/model.py
index a51cfc532b59..5e8f0aea4985 100644
--- a/python/mxnet/model.py
+++ b/python/mxnet/model.py
@@ -37,6 +37,7 @@ def _create_kvstore(kvstore, num_device, arg_params):
     """Create kvstore
     This function selects and creates a proper kvstore if given the kvstore type.
+
     Parameters
     ----------
     kvstore : KVStore or str
@@ -124,17 +125,18 @@ def _train_multi_device(symbol, ctx, arg_names, param_names, aux_names,
                         logger=None, work_load_list=None, monitor=None):
     """Internal training function on multiple devices.
     This function also works for a single device.
+
     Parameters
     ----------
     symbol : Symbol
         The network configuration
     ctx : list of Context
         The training devices.
-    arg_names: list of str
+    arg_names : list of str
         Name of all arguments of the network.
-    param_names: list of str
+    param_names : list of str
         Name of all trainable parameters of the network.
-    aux_names: list of str
+    aux_names : list of str
         Name of all auxiliary states of the network.
     input_shape : tuple
         Shape of input data batch.
@@ -295,6 +297,7 @@ def _train_multi_device(symbol, ctx, arg_names, param_names, aux_names,
 def save_checkpoint(prefix, epoch, symbol, arg_params, aux_params):
     """Checkpoint the model data into file.
+
     Parameters
     ----------
     prefix : str
         Prefix of model name.
     epoch : int
         The epoch number of the model.
     symbol : Symbol
         The input symbol
     arg_params : dict of str to NDArray
         Model parameter, dict of name to NDArray of net's weights.
     aux_params : dict of str to NDArray
         Model parameter, dict of name to NDArray of net's auxiliary states.
+
     Notes
     -----
     - ``prefix-symbol.json`` will be saved for symbol.
@@ -322,6 +326,7 @@
 def load_checkpoint(prefix, epoch):
     """Load model checkpoint from file.
+
     Parameters
     ----------
     prefix : str
         Prefix of model name.
     epoch : int
         Epoch number of model we would like to load.
     Returns
     -------
     symbol : Symbol
         The symbol configuration of computation network.
     arg_params : dict of str to NDArray
         Model parameter, dict of name to NDArray of net's weights.
     aux_params : dict of str to NDArray
         Model parameter, dict of name to NDArray of net's auxiliary states.
+
     Notes
     -----
     - symbol will be loaded from ``prefix-symbol.json``.
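+
+    Examples
+    --------
+    A matching save/load pair, as a minimal sketch (``net``, ``arg_params`` and
+    ``aux_params`` are placeholder names, not part of the API):
+
+    >>> save_checkpoint('mymodel', 10, net, arg_params, aux_params)
+    >>> net, arg_params, aux_params = load_checkpoint('mymodel', 10)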
@@ -357,6 +363,7 @@ def load_checkpoint(prefix, epoch):
 class FeedForward(BASE_ESTIMATOR):
     """Model class of MXNet for training and predicting feedforward nets.
     This class is designed for a single-data single output supervised network.
+
     Parameters
     ----------
     symbol : Symbol
@@ -535,11 +542,13 @@ def _init_eval_iter(self, eval_data):
     def predict(self, X, num_batch=None):
         """Run the prediction, always only use one device.
+
         Parameters
         ----------
         X : mxnet.DataIter
         num_batch : int or None
             The number of batches to run. Go through all batches if None.
+
         Returns
         -------
         y : numpy.ndarray, or a list of numpy.ndarray if the network has multiple outputs.
@@ -580,6 +589,7 @@ def fit(self, X, y=None, eval_data=None, eval_metric='acc',
             epoch_end_callback=None, batch_end_callback=None,
             kvstore='local', logger=None, work_load_list=None, monitor=None):
         """Fit the model.
+
         Parameters
         ----------
         X : DataIter, or numpy.ndarray/NDArray
@@ -617,6 +627,8 @@ def fit(self, X, y=None, eval_data=None, eval_metric='acc',
         work_load_list : list of float or int, optional
             The list of work load for different devices, in the same order as ctx
+        monitor : :class:`monitor.Monitor`
+            Monitoring weights and gradients.
         """

         data = self._init_iter(X, y, is_train=True)
@@ -665,13 +677,16 @@ def save(self, prefix, epoch=None):
         The advantage of load/save is that the file is language agnostic.
         This means a file saved with save can be loaded by other language
         bindings of mxnet. You also get the benefit of being able to directly
         load/save from cloud storage (S3, HDFS).
+
         Parameters
         ----------
         prefix : str
             Prefix of model name.
+
         See Also
         --------
         Symbol.load : the method to load the model back.
+
         Notes
         -----
         - ``prefix-symbol.json`` will be saved for symbol.
@@ -685,6 +700,7 @@ def save(self, prefix, epoch=None):
     @staticmethod
     def load(prefix, epoch, ctx=None, **kwargs):
         """Load model checkpoint from file.
+
         Parameters
         ----------
         prefix : str
@@ -695,10 +711,12 @@ def load(prefix, epoch, ctx=None, **kwargs):
             The device context of training and prediction.
         kwargs : dict
             Other parameters for model, including num_epoch, optimizer and numpy_batch_size.
+
         Returns
        -------
         model : FeedForward
             The loaded model that can be used for prediction.
+
         Notes
         -----
         - symbol will be loaded from ``prefix-symbol.json``.
@@ -715,10 +733,11 @@ def create(symbol, X, y=None, ctx=None,
               num_epoch=None, epoch_size=None, optimizer='sgd', initializer=Uniform(0.01),
               eval_data=None, eval_metric='acc',
               epoch_end_callback=None, batch_end_callback=None,
-              kvstore='local', logger=None, work_load_list=None, **kwargs):
+              kvstore='local', logger=None, work_load_list=None, monitor=None, **kwargs):
         """Functional style to create a model.
         This function is more consistent with functional languages such as R,
         where mutation is not allowed.
+
         Parameters
         ----------
         symbol : Symbol
@@ -764,6 +783,8 @@ def create(symbol, X, y=None, ctx=None,
         work_load_list : list of float or int, optional
             The list of work load for different devices, in the same order as ctx
+        monitor : :class:`monitor.Monitor`
+            Monitoring weights and gradients.
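+
+        A minimal sketch of a call (``net`` and ``train_iter`` are placeholders
+        for a network symbol and a training DataIter):
+
+        >>> model = FeedForward.create(symbol=net, X=train_iter, ctx=mx.cpu(),
+        ...                            num_epoch=10, monitor=mx.monitor.Monitor(100))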
""" model = FeedForward(symbol, ctx=ctx, num_epoch=num_epoch, epoch_size=epoch_size, optimizer=optimizer, initializer=initializer, **kwargs) @@ -772,5 +793,6 @@ def create(symbol, X, y=None, ctx=None, batch_end_callback=batch_end_callback, kvstore=kvstore, logger=logger, - work_load_list=work_load_list) + work_load_list=work_load_list, + monitor=monitor) return model From 8c5bfa9b26df0df994ce2e41de7842d77755bdd1 Mon Sep 17 00:00:00 2001 From: haijieg Date: Fri, 1 Apr 2016 15:05:19 -0700 Subject: [PATCH 7/8] Fix SFrameIter doc --- python/mxnet/io.py | 49 +++++++++++++++++++++++++++++++++------------- 1 file changed, 35 insertions(+), 14 deletions(-) diff --git a/python/mxnet/io.py b/python/mxnet/io.py index faf2ce681505..c07c7a2b8062 100644 --- a/python/mxnet/io.py +++ b/python/mxnet/io.py @@ -368,21 +368,42 @@ def getpad(self): class SFrameIter(DataIter): - def __init__(self, sframe, data_field, label_field=None, batch_size=1, data_name='data', label_name='softmax_label'): - """ - Iterator over from SFrame + """DataIter from SFrame + Provides DataIter interface for SFrame, a highly scalable columnar DataFrame. + The iterator can simultaneously iterate over multiple columns indicated by `data_field` and `label_field`. + `data_field` can refer either a single image typed column or multiple numerical columns (int, float or array). + `label_field` con only refer to a single numerical column (int, float or array). - Parameters - ---------- - sframe: SFrame object - source SFrame - data_field: string or list(string) - select fields of training data. For image or array type, only support string - label_field: string (optional) - label field in SFrame - batch_size: int (optional) - batch size - """ + Parameters + ---------- + sframe : SFrame object + source SFrame + data_field : string or list(string) + data fields of the SFrame. The seleted fields may be either a single image typed column, + or multiple numerical columns (int, float, array). + label_field : string (optional) + label field in SFrame + batch_size : int + batch size + + Examples + -------- + >>> import sframe as sf + >>> import mxnet as mx + + >>> data = sf.SFrame({'x': [1,2,3], 'y': [.1, .5, .5], 'z': [[1,1,1], [2,2,2,], [3,3,3]]}) + >>> dataiter = mx.io.SFrameIter(sframe=data, data_field=['x', 'z'], label_field='z') + + >>> image_data = sf.SFrame('http://s3.amazonaws.com/dato-datasets/mnist/sframe/train') + >>> image_data_iter = mx.io.SFrameIter(sframe=data, data_field=['image'], label_field='label', batch_size=100) + + Notes + ----- + - Image column must contain images of the same size. + - Array column must contain arrays of the same length. 
+ """ + + def __init__(self, sframe, data_field, label_field=None, batch_size=1, data_name='data', label_name='softmax_label'): super(SFrameIter, self).__init__() if not isinstance(sframe, gl.SFrame): From 9ceaef93fccc6c6b0653245131327cfe45a7ddfc Mon Sep 17 00:00:00 2001 From: haijieg Date: Mon, 4 Apr 2016 13:25:08 -0700 Subject: [PATCH 8/8] improve image iter speed --- python/mxnet/io.py | 2 +- src/ndarray/ndarray.cc | 28 +++++++++++++++++----------- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/python/mxnet/io.py b/python/mxnet/io.py index c07c7a2b8062..fdb32408d191 100644 --- a/python/mxnet/io.py +++ b/python/mxnet/io.py @@ -477,7 +477,7 @@ def _infer_column_shape(self, sarray): else: return (lengths.max(), ) elif dtype is gl.Image: - first_image = sarray.dropna()[0] + first_image = sarray.head(1)[0] return (first_image.channels, first_image.height, first_image.width) def infer_shape(self): diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 1a88d7528be8..23ca1cc8a674 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -584,7 +584,7 @@ void NDArray::SyncCopyFromSFrame(const graphlab::flexible_type *data, size_t siz auto type = data[0].get_type(); if (type == graphlab::flex_type_enum::IMAGE) { CHECK_EQ(size, 1) << "Image data only support one input field"; - graphlab::image_type img = data[0].get(); + const graphlab::image_type& img = data[0].get(); mshadow::Tensor batch_tensor = dst.GetWithShape( mshadow::Shape4(dshape[0], img.m_channels, img.m_height, img.m_width)); @@ -604,16 +604,22 @@ void NDArray::SyncCopyFromSFrame(const graphlab::flexible_type *data, size_t siz } else if (img.m_format == graphlab::Format::PNG) { graphlab::decode_png((const char*)img.get_image_data(), img.m_image_data_size, &buf, length); } - img.m_image_data.reset(buf); - img.m_image_data_size = length; - img.m_format = graphlab::Format::RAW_ARRAY; - } - size_t cnt = 0; - const unsigned char* raw_data = img.get_image_data(); - for (size_t i = 0; i < img.m_height; ++i) { - for (size_t j = 0; j < img.m_width; ++j) { - for (size_t k = 0; k < img.m_channels; ++k) { - batch_tensor[idx][k][i][j] = raw_data[cnt++]; + size_t cnt = 0; + for (size_t i = 0; i < img.m_height; ++i) { + for (size_t j = 0; j < img.m_width; ++j) { + for (size_t k = 0; k < img.m_channels; ++k) { + batch_tensor[idx][k][i][j] = buf[cnt++]; + } + } + } + } else { + size_t cnt = 0; + const unsigned char* raw_data = img.get_image_data(); + for (size_t i = 0; i < img.m_height; ++i) { + for (size_t j = 0; j < img.m_width; ++j) { + for (size_t k = 0; k < img.m_channels; ++k) { + batch_tensor[idx][k][i][j] = raw_data[cnt++]; + } } } }