This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support of plug and play fit_batch and evaluate_batch (#16982)
* Add support of plug and play fit_batch and evaluate_batch
* Add check for the validity of the estimator model
* Rename estimator model as batch processor
* Remove unused import
* Add documentation of the batch processor class
* Refine the documentation of the batch processor
* Fix merge bugs
* Fix bugs introduced during merge
* Fix sanity check failures
* Fix CI bugs
- Loading branch information
Showing 4 changed files with 248 additions and 60 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
105 changes: 105 additions & 0 deletions
105
python/mxnet/gluon/contrib/estimator/batch_processor.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
# coding: utf-8 | ||
# pylint: disable=wildcard-import, unused-argument, too-many-ancestors | ||
"""Gluon Batch Processor for Estimators""" | ||
|
||
from ...utils import split_and_load | ||
from .... import autograd | ||
|
||
__all__ = ['BatchProcessor'] | ||
|
||
class BatchProcessor(object):
    """BatchProcessor Class for plug and play fit_batch & evaluate_batch

    During training or validation, data are divided into minibatches for
    processing. This class provides hooks for training or validating on a
    minibatch of data. Users may supply customized fit_batch() and
    evaluate_batch() methods by subclassing and overriding them.

    :py:class:`BatchProcessor` can be used to replace fit_batch() and
    evaluate_batch() in the base estimator class.
    """

    def __init__(self):
        pass

    def _get_data_and_label(self, batch, ctx, batch_axis=0):
        # Shard the (data, label) pair of a batch across the given contexts.
        raw_data, raw_label = batch[0], batch[1]
        sharded_data = split_and_load(raw_data, ctx_list=ctx, batch_axis=batch_axis)
        sharded_label = split_and_load(raw_label, ctx_list=ctx, batch_axis=batch_axis)
        return sharded_data, sharded_label

    def evaluate_batch(self, estimator, val_batch, batch_axis=0):
        """Evaluate the estimator model on a batch of validation data.

        Parameters
        ----------
        estimator : Estimator
            Reference to the estimator
        val_batch : tuple
            Data and label of a batch from the validation data loader.
        batch_axis : int, default 0
            Batch axis to split the validation data into devices.
        """
        data, label = self._get_data_and_label(val_batch, estimator.context, batch_axis)
        # Forward pass through the evaluation network, one shard per device.
        pred = [estimator.eval_net(shard) for shard in data]
        loss = [estimator.evaluation_loss(p, y) for p, y in zip(pred, label)]
        return data, label, pred, loss

    def fit_batch(self, estimator, train_batch, batch_axis=0):
        """Trains the estimator model on a batch of training data.

        Parameters
        ----------
        estimator : Estimator
            Reference to the estimator
        train_batch : tuple
            Data and label of a batch from the training data loader.
        batch_axis : int, default 0
            Batch axis to split the training data into devices.

        Returns
        -------
        data: List of NDArray
            Sharded data from the batch. Data is sharded with
            `gluon.split_and_load`.
        label: List of NDArray
            Sharded label from the batch. Labels are sharded with
            `gluon.split_and_load`.
        pred: List of NDArray
            Prediction on each of the sharded inputs.
        loss: List of NDArray
            Loss on each of the sharded inputs.
        """
        data, label = self._get_data_and_label(train_batch, estimator.context, batch_axis)

        # Record the forward pass so gradients can be computed.
        with autograd.record():
            pred = [estimator.net(shard) for shard in data]
            loss = [estimator.loss(p, y) for p, y in zip(pred, label)]

        # Backpropagate each shard's loss independently.
        for shard_loss in loss:
            shard_loss.backward()

        return data, label, pred, loss
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
''' Unit tests for Gluon Batch Processor ''' | ||
|
||
import sys | ||
import unittest | ||
import warnings | ||
|
||
import mxnet as mx | ||
from mxnet import gluon | ||
from mxnet.gluon import nn | ||
from mxnet.gluon.contrib.estimator import * | ||
from mxnet.gluon.contrib.estimator.event_handler import * | ||
from mxnet.gluon.contrib.estimator.batch_processor import BatchProcessor | ||
from nose.tools import assert_raises | ||
|
||
def _get_test_network():
    """Build a minimal one-layer Gluon network for the tests."""
    network = nn.Sequential()
    network.add(nn.Dense(4, activation='relu', flatten=False))
    return network
|
||
|
||
def _get_test_data():
    """Return a (DataLoader, NDArrayIter) pair over the same random data."""
    batch_size = 4
    features = mx.nd.random.uniform(shape=(10, 3))
    targets = mx.nd.random.uniform(shape=(10, 4))
    # Input dataloader
    dataset = gluon.data.dataset.ArrayDataset(features, targets)
    loader = gluon.data.DataLoader(dataset, batch_size=batch_size)
    iterator = mx.io.NDArrayIter(data=features, label=targets, batch_size=batch_size)
    return loader, iterator
|
||
def test_batch_processor_fit():
    ''' test estimator with different train data types '''
    network = _get_test_network()
    loader, iterator = _get_test_data()
    num_epochs = 1
    ctx = mx.cpu()
    loss = gluon.loss.L2Loss()
    acc = mx.metric.Accuracy()
    network.initialize(ctx=ctx)
    processor = BatchProcessor()
    trainer = gluon.Trainer(network.collect_params(), 'sgd', {'learning_rate': 0.001})
    estimator = Estimator(net=network,
                          loss=loss,
                          train_metrics=acc,
                          trainer=trainer,
                          context=ctx,
                          batch_processor=processor)

    # A DataLoader is the supported input type.
    estimator.fit(train_data=loader, epochs=num_epochs)

    # A DataIter is rejected.
    with assert_raises(ValueError):
        estimator.fit(train_data=iterator, epochs=num_epochs)

    # A raw NDArray is rejected.
    with assert_raises(ValueError):
        estimator.fit(train_data=[mx.nd.ones(shape=(10, 3))], epochs=num_epochs)
|
||
|
||
def test_batch_processor_validation():
    ''' test different validation data types'''
    net = _get_test_network()
    dataloader, dataiter = _get_test_data()
    num_epochs = 1
    ctx = mx.cpu()
    loss = gluon.loss.L2Loss()
    acc = mx.metric.Accuracy()
    # Validation uses a different loss than training to exercise evaluation_loss.
    evaluation_loss = gluon.loss.L1Loss()
    net.initialize(ctx=ctx)
    processor = BatchProcessor()
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
    est = Estimator(net=net,
                    loss=loss,
                    train_metrics=acc,
                    trainer=trainer,
                    context=ctx,
                    evaluation_loss=evaluation_loss,
                    batch_processor=processor)
    # Input dataloader: the supported validation data type.
    est.fit(train_data=dataloader,
            val_data=dataloader,
            epochs=num_epochs)

    # Removed dead code: train_metrics/val_metrics reads and a
    # ValidationHandler that was constructed but never passed to fit().

    # A DataIter is rejected for validation data.
    with assert_raises(ValueError):
        est.fit(train_data=dataiter,
                val_data=dataiter,
                epochs=num_epochs)
    # Input NDArray is rejected for validation data.
    with assert_raises(ValueError):
        est.fit(train_data=[mx.nd.ones(shape=(10, 3))],
                val_data=[mx.nd.ones(shape=(10, 3))],
                epochs=num_epochs)