Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add support of plug and play fit_batch and evaluate_batch (#16982)
Browse files Browse the repository at this point in the history
* Add support of plug and play fit_batch and evaluate_batch

* Add check for the validity of the estimator model

* Rename estimator model as batch processor

* Remove unused import

* Add documentation of the batch processor class

* refine the documentation of the batch processor

* Fix merge bugs

* fix bugs introduced during merge

* fix sanity check failures

* fix CI bugs
  • Loading branch information
liuzh47 authored and leezu committed Dec 11, 2019
1 parent 27389b1 commit c82af38
Show file tree
Hide file tree
Showing 4 changed files with 248 additions and 60 deletions.
2 changes: 2 additions & 0 deletions python/mxnet/gluon/contrib/estimator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,7 @@
"""Gluon Estimator Module"""
from . import estimator
from . import event_handler
from . import batch_processor
from .estimator import *
from .event_handler import *
from .batch_processor import *
105 changes: 105 additions & 0 deletions python/mxnet/gluon/contrib/estimator/batch_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=wildcard-import, unused-argument, too-many-ancestors
"""Gluon Batch Processor for Estimators"""

from ...utils import split_and_load
from .... import autograd

__all__ = ['BatchProcessor']

class BatchProcessor(object):
    """Pluggable per-minibatch training and evaluation logic for an estimator.

    During training or validation the data is consumed one minibatch at a
    time. This class supplies the two hooks that handle a single minibatch,
    :py:meth:`fit_batch` and :py:meth:`evaluate_batch`. To customize either
    step, inherit from this class and override the corresponding method.

    :py:class:`BatchProcessor` can be used to replace fit_batch() and
    evaluate_batch() in the base estimator class.
    """

    def __init__(self):
        """Create a processor; the default implementation keeps no state."""

    def _get_data_and_label(self, batch, ctx, batch_axis=0):
        """Shard ``batch[0]`` (data) and ``batch[1]`` (label) over ``ctx``.

        Both elements are passed to ``split_and_load`` with ``ctx`` as the
        context list and ``batch_axis`` as the axis to split along.
        """
        raw_data, raw_label = batch[0], batch[1]
        sharded_data = split_and_load(raw_data, ctx_list=ctx, batch_axis=batch_axis)
        sharded_label = split_and_load(raw_label, ctx_list=ctx, batch_axis=batch_axis)
        return sharded_data, sharded_label

    def evaluate_batch(self, estimator,
                       val_batch,
                       batch_axis=0):
        """Evaluate the estimator model on a batch of validation data.

        Parameters
        ----------
        estimator : Estimator
            Reference to the estimator. Its ``context``, ``eval_net`` and
            ``evaluation_loss`` attributes are used.
        val_batch : tuple
            Data and label of a batch from the validation data loader.
        batch_axis : int, default 0
            Batch axis to split the validation data into devices.

        Returns
        -------
        data : list
            Sharded data, as returned by ``_get_data_and_label``.
        label : list
            Sharded label, as returned by ``_get_data_and_label``.
        pred : list
            Output of ``estimator.eval_net`` for each data shard.
        loss : list
            Output of ``estimator.evaluation_loss`` for each
            (prediction, label) pair.
        """
        data, label = self._get_data_and_label(val_batch, estimator.context, batch_axis)

        pred = []
        for shard in data:
            pred.append(estimator.eval_net(shard))

        loss = []
        for shard_pred, shard_label in zip(pred, label):
            loss.append(estimator.evaluation_loss(shard_pred, shard_label))

        return data, label, pred, loss

    def fit_batch(self, estimator,
                  train_batch,
                  batch_axis=0):
        """Trains the estimator model on a batch of training data.

        The forward pass and the loss are computed inside
        ``autograd.record()``; afterwards ``backward()`` is called on the
        loss of every shard. This method does not invoke the trainer itself.

        Parameters
        ----------
        estimator : Estimator
            Reference to the estimator. Its ``context``, ``net`` and ``loss``
            attributes are used.
        train_batch : tuple
            Data and label of a batch from the training data loader.
        batch_axis : int, default 0
            Batch axis to split the training data into devices.

        Returns
        -------
        data: List of NDArray
            Sharded data from the batch. Data is sharded with
            `gluon.split_and_load`.
        label: List of NDArray
            Sharded label from the batch. Labels are sharded with
            `gluon.split_and_load`.
        pred: List of NDArray
            Prediction on each of the sharded inputs.
        loss: List of NDArray
            Loss on each of the sharded inputs.
        """
        data, label = self._get_data_and_label(train_batch, estimator.context, batch_axis)

        with autograd.record():
            pred = []
            for shard in data:
                pred.append(estimator.net(shard))

            loss = []
            for shard_pred, shard_label in zip(pred, label):
                loss.append(estimator.loss(shard_pred, shard_label))

        for shard_loss in loss:
            shard_loss.backward()

        return data, label, pred, loss
84 changes: 24 additions & 60 deletions python/mxnet/gluon/contrib/estimator/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@
from ...loss import Loss as gluon_loss
from ...trainer import Trainer
from ...utils import split_and_load
from .... import autograd
from ....context import Context, cpu, gpu, num_gpus
from ....metric import Loss as metric_loss
from .batch_processor import BatchProcessor

__all__ = ['Estimator']

Expand Down Expand Up @@ -84,7 +84,8 @@ class Estimator(object):
the naming in mxnet Gluon API, please refer to the site
(https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/blocks/naming.html)
for further information.
batch_processor : BatchProcessor, default None
Provides the fit_batch() and evaluate_batch() methods that are called for each minibatch.
If None, a default BatchProcessor instance is used.
"""

logger = None
Expand Down Expand Up @@ -113,7 +114,8 @@ def __init__(self, net,
trainer=None,
context=None,
evaluation_loss=None,
eval_net=None):
eval_net=None,
batch_processor=None):
self.net = net
self.loss = self._check_loss(loss)
self._train_metrics = _check_metrics(train_metrics)
Expand All @@ -133,6 +135,7 @@ def __init__(self, net,
self.context = self._check_context(context)
self._initialize(initializer)
self.trainer = self._check_trainer(trainer)
self.batch_processor = self._check_batch_processor(batch_processor)

def _check_loss(self, loss):
if not isinstance(loss, gluon_loss):
Expand Down Expand Up @@ -173,6 +176,18 @@ def _check_context(self, context):
context = [cpu()]
return context

def _check_batch_processor(self, batch_processor):
    """Return the batch processor to use, validating a user-supplied one.

    When ``batch_processor`` is None a default ``BatchProcessor`` is created.
    Otherwise the object must expose callable ``fit_batch`` and
    ``evaluate_batch`` attributes; it is returned unchanged if it does.

    Raises
    ------
    ValueError
        If either ``fit_batch`` or ``evaluate_batch`` is missing or not callable.
    """
    if batch_processor is None:
        # nothing supplied by the user: fall back to the stock implementation
        return BatchProcessor()

    # duck-typed check: only the two hook methods are required
    fit_hook = getattr(batch_processor, 'fit_batch', None)
    evaluate_hook = getattr(batch_processor, 'evaluate_batch', None)
    if not (callable(fit_hook) and callable(evaluate_hook)):
        raise ValueError('Customized Batch Processor must contain fit_batch()'
                         ' and evaluate_batch() methods')
    return batch_processor

def _initialize(self, initializer):
# initialize the network
if not self._is_initialized():
Expand Down Expand Up @@ -254,24 +269,6 @@ def train_metrics(self):
def val_metrics(self):
return self._val_metrics

def evaluate_batch(self,
                   val_batch,
                   batch_axis=0):
    """Evaluate model on a batch of validation data.

    Parameters
    ----------
    val_batch : tuple
        Data and label of a batch from the validation data loader.
    batch_axis : int, default 0
        Batch axis to split the validation data into devices.

    Returns
    -------
    tuple
        ``(data, label, pred, loss)``: the sharded data and label returned by
        ``self._get_data_and_label``, the output of ``self.eval_net`` for each
        data shard, and the output of ``self.evaluation_loss`` for each
        (prediction, label) pair.
    """
    data, label = self._get_data_and_label(val_batch, self.context, batch_axis)

    pred = []
    for shard in data:
        pred.append(self.eval_net(shard))

    loss = []
    for shard_pred, shard_label in zip(pred, label):
        loss.append(self.evaluation_loss(shard_pred, shard_label))

    return data, label, pred, loss

def evaluate(self,
val_data,
batch_axis=0,
Expand Down Expand Up @@ -300,6 +297,7 @@ def evaluate(self,

for metric in self.val_metrics:
metric.reset()
estimator_ref = self

event_handlers = self._prepare_default_validation_handlers(event_handlers)

Expand All @@ -315,50 +313,16 @@ def evaluate(self,
for handler in batch_begin:
handler.batch_begin(estimator_ref, batch=batch)

_, label, pred, loss = self.evaluate_batch(batch, batch_axis)
_, label, pred, loss = \
self.batch_processor.evaluate_batch(estimator_ref, batch,
batch_axis)

for handler in batch_end:
handler.batch_end(estimator_ref, batch=batch, pred=pred, label=label, loss=loss)

for handler in epoch_end:
handler.epoch_end(estimator_ref)

def fit_batch(self, train_batch, batch_axis=0):
    """Trains the model on a batch of training data.

    The forward pass and the loss are computed inside ``autograd.record()``;
    afterwards ``backward()`` is called on the loss of every shard. This
    method does not invoke the trainer itself.

    Parameters
    ----------
    train_batch : tuple
        Data and label of a batch from the training data loader.
    batch_axis : int, default 0
        Batch axis to split the training data into devices.

    Returns
    -------
    data: List of NDArray
        Sharded data from the batch. Data is sharded with
        `gluon.split_and_load`.
    label: List of NDArray
        Sharded label from the batch. Labels are sharded with
        `gluon.split_and_load`.
    pred: List of NDArray
        Prediction on each of the sharded inputs.
    loss: List of NDArray
        Loss on each of the sharded inputs.
    """
    data, label = self._get_data_and_label(train_batch, self.context, batch_axis)

    # The former ``batch_size`` local was computed here but never read, so it
    # has been dropped.
    with autograd.record():
        pred = [self.net(x) for x in data]
        loss = [self.loss(y_hat, y) for y_hat, y in zip(pred, label)]

    for shard_loss in loss:
        shard_loss.backward()

    return data, label, pred, loss

def fit(self, train_data,
val_data=None,
epochs=None,
Expand Down Expand Up @@ -432,8 +396,8 @@ def fit(self, train_data,
for handler in batch_begin:
handler.batch_begin(estimator_ref, batch=batch)

_, label, pred, loss = self.fit_batch(batch, batch_axis)

_, label, pred, loss = self.batch_processor.fit_batch(estimator_ref,
batch, batch_axis)
# batch end

batch_end_result = []
Expand Down
117 changes: 117 additions & 0 deletions tests/python/unittest/test_gluon_batch_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

''' Unit tests for Gluon Batch Processor '''

import sys
import unittest
import warnings

import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
from mxnet.gluon.contrib.estimator import *
from mxnet.gluon.contrib.estimator.event_handler import *
from mxnet.gluon.contrib.estimator.batch_processor import BatchProcessor
from nose.tools import assert_raises

def _get_test_network():
    ''' build the network shared by the tests: a Sequential holding one 4-unit relu Dense layer '''
    network = nn.Sequential()
    dense_layer = nn.Dense(4, activation='relu', flatten=False)
    network.add(dense_layer)
    return network


def _get_test_data():
    ''' return the same random (10, 3) -> (10, 4) data as a DataLoader and as an NDArrayIter '''
    num_samples, in_units, out_units = 10, 3, 4
    batch_size = 4
    features = mx.nd.random.uniform(shape=(num_samples, in_units))
    targets = mx.nd.random.uniform(shape=(num_samples, out_units))
    # Gluon DataLoader over the paired arrays
    paired = gluon.data.dataset.ArrayDataset(features, targets)
    loader = gluon.data.DataLoader(paired, batch_size=batch_size)
    # NDArrayIter over the same arrays
    nd_iter = mx.io.NDArrayIter(data=features, label=targets, batch_size=batch_size)
    return loader, nd_iter

def test_batch_processor_fit():
    ''' test estimator with different train data types '''
    network = _get_test_network()
    loader, nd_iter = _get_test_data()
    epochs = 1
    device = mx.cpu()
    l2_loss = gluon.loss.L2Loss()
    accuracy = mx.metric.Accuracy()
    network.initialize(ctx=device)
    batch_processor = BatchProcessor()
    sgd_trainer = gluon.Trainer(network.collect_params(), 'sgd', {'learning_rate': 0.001})
    est = Estimator(net=network,
                    loss=l2_loss,
                    train_metrics=accuracy,
                    trainer=sgd_trainer,
                    context=device,
                    batch_processor=batch_processor)

    # a Gluon DataLoader is accepted as training data
    est.fit(train_data=loader, epochs=epochs)

    # an NDArrayIter and a plain list of NDArray must both be rejected
    rejected_inputs = (nd_iter, [mx.nd.ones(shape=(10, 3))])
    for bad_train_data in rejected_inputs:
        with assert_raises(ValueError):
            est.fit(train_data=bad_train_data, epochs=epochs)


def test_batch_processor_validation():
    ''' test different validation data types'''
    network = _get_test_network()
    loader, nd_iter = _get_test_data()
    epochs = 1
    device = mx.cpu()
    l2_loss = gluon.loss.L2Loss()
    accuracy = mx.metric.Accuracy()
    l1_eval_loss = gluon.loss.L1Loss()
    network.initialize(ctx=device)
    batch_processor = BatchProcessor()
    sgd_trainer = gluon.Trainer(network.collect_params(), 'sgd', {'learning_rate': 0.001})
    est = Estimator(net=network,
                    loss=l2_loss,
                    train_metrics=accuracy,
                    trainer=sgd_trainer,
                    context=device,
                    evaluation_loss=l1_eval_loss,
                    batch_processor=batch_processor)

    # a Gluon DataLoader is accepted for both training and validation data
    est.fit(train_data=loader, val_data=loader, epochs=epochs)

    # using validation handler
    # NOTE(review): the handler is constructed here but never passed to
    # est.fit(), and the two metric lists are not used afterwards; only the
    # property accesses and the constructor call are exercised.
    train_metrics = est.train_metrics
    val_metrics = est.val_metrics
    validation_handler = ValidationHandler(val_data=loader, eval_fn=est.evaluate)

    # an NDArrayIter and a plain list of NDArray must both be rejected
    rejected_pairs = (
        (nd_iter, nd_iter),
        ([mx.nd.ones(shape=(10, 3))], [mx.nd.ones(shape=(10, 3))]),
    )
    for bad_train_data, bad_val_data in rejected_pairs:
        with assert_raises(ValueError):
            est.fit(train_data=bad_train_data,
                    val_data=bad_val_data,
                    epochs=epochs)

0 comments on commit c82af38

Please sign in to comment.