This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[PY_METRIC] Make it consistent with numpy sklearn style, support nump… #165

Merged · 2 commits merged on Sep 27, 2015
58 changes: 43 additions & 15 deletions python/mxnet/metric.py
@@ -1,24 +1,24 @@
 # pylint: disable=invalid-name
 """Online evaluation metric module."""
 from .base import string_types
-import numpy as np
+import numpy
 
 class EvalMetric(object):
     """Base class of all evaluation metrics."""
     def __init__(self, name):
         self.name = name
         self.reset()
 
-    def update(self, pred, label):
+    def update(self, label, pred):
         """Update the internal evaluation.
 
         Parameters
         ----------
-        pred : NDArray
-            Predicted value.
-
         label : NDArray
             The label of the data.
+
+        pred : NDArray
+            Predicted value.
         """
         raise NotImplementedError()
@@ -45,28 +45,56 @@ class Accuracy(EvalMetric):
     def __init__(self):
         super(Accuracy, self).__init__('accuracy')
 
-    def update(self, pred, label):
+    def update(self, label, pred):
         pred = pred.asnumpy()
         label = label.asnumpy().astype('int32')
-        py = np.argmax(pred, axis=1)
-        self.sum_metric += np.sum(py == label)
+        py = numpy.argmax(pred, axis=1)
+        self.sum_metric += numpy.sum(py == label)
         self.num_inst += label.size
 
 
 class CustomMetric(EvalMetric):
-    """Calculate accuracy"""
-    def __init__(self, feval):
-        name = feval.__name__
-        if name.find('<') != -1:
-            name = 'custom(%s)' % name
+    """Custom evaluation metric that takes a NDArray function.
+
+    Parameters
+    ----------
+    feval : callable(label, pred)
+        Customized evaluation function.
+
+    name : str, optional
+        The name of the metric.
+    """
+    def __init__(self, feval, name=None):
+        if name is None:
+            name = feval.__name__
+            if name.find('<') != -1:
+                name = 'custom(%s)' % name
         super(CustomMetric, self).__init__(name)
         self._feval = feval
 
-    def update(self, pred, label):
-        self.sum_metric += self._feval(pred, label)
+    def update(self, label, pred):
+        self.sum_metric += self._feval(label, pred)
         self.num_inst += 1
 
 
+def np(numpy_feval, name=None):
+    """Create a customized metric from numpy function.
+
+    Parameters
+    ----------
+    numpy_feval : callable(label, pred)
+        Customized evaluation function.
+
+    name : str, optional
+        The name of the metric.
+    """
+    def feval(label, pred):
+        """Internal eval function."""
+        return numpy_feval(label.asnumpy(), pred.asnumpy())
+    feval.__name__ = numpy_feval.__name__
+    return CustomMetric(feval, name)
+
+
 def create(metric):
     """Create an evaluation metric.
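For reference, the new mx.metric.np helper turns any numpy-style function with a (label, pred) signature into an EvalMetric by calling .asnumpy() on both arguments before delegating. A minimal usage sketch; the rmse function and the toy arrays below are illustrative, not part of this diff:

import numpy as np
import mxnet as mx

def rmse(label, pred):
    """Root mean squared error, computed on plain numpy arrays."""
    return np.sqrt(np.mean((label - pred) ** 2))

# Wrap rmse so it can be fed NDArrays directly.
metric = mx.metric.np(rmse)
metric.update(mx.nd.array([1.0, 2.0, 3.0]), mx.nd.array([1.1, 1.9, 3.2]))
print(metric.get())  # ('rmse', <running average of the per-batch scores>)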
4 changes: 2 additions & 2 deletions python/mxnet/model.py
@@ -274,7 +274,7 @@ def _train_multi_device(symbol, ctx, input_shape,
             else:
                 epoch_end_callback(nbatch)
             # evaluate at end, so out_cpu_array can lazy copy
-            eval_metric.update(out_cpu_array, label)
+            eval_metric.update(label, out_cpu_array)
         # reset training data after iteration finish
         train_data.reset()
         name, value = eval_metric.get()
@@ -294,7 +294,7 @@ def _train_multi_device(symbol, ctx, input_shape,
                 for texec, islice in zip(train_execs, slices):
                     texec.forward(is_train=False)
                     texec.outputs[0].copyto(out_cpu_array[islice])
-                eval_metric.update(out_cpu_array, label)
+                eval_metric.update(label, out_cpu_array)
             eval_data.reset()
             name, value = eval_metric.get()
             logger.info('Iteration[%d] Validation-%s=%f', iteration, name, value)
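Both trainer call sites now pass the label first, matching the updated EvalMetric.update(label, pred) contract. A sketch of what a custom subclass looks like under the new argument order; the TopKAccuracy class and its k parameter are hypothetical, not part of this change:

import numpy

from mxnet.metric import EvalMetric

class TopKAccuracy(EvalMetric):
    """Counts a sample as correct if the true class is among the k highest scores."""
    def __init__(self, k=5):
        super(TopKAccuracy, self).__init__('top_%d_accuracy' % k)
        self.k = k

    def update(self, label, pred):
        # Label first, prediction second, as in the updated base class.
        pred = pred.asnumpy()
        label = label.asnumpy().astype('int32')
        # Column indices of the k highest scores in each row.
        topk = numpy.argsort(pred, axis=1)[:, -self.k:]
        self.sum_metric += numpy.sum(topk == label[:, numpy.newaxis])
        self.num_inst += label.size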
6 changes: 2 additions & 4 deletions tests/python/train/test_mlp.py
@@ -16,9 +16,7 @@
 fc3 = mx.symbol.FullyConnected(act2, name='fc3', num_hidden=10)
 softmax = mx.symbol.Softmax(fc3, name = 'sm')
 
-def accuracy(pred, label):
-    pred = pred.asnumpy()
-    label = label.asnumpy().astype('int32')
+def accuracy(label, pred):
     py = np.argmax(pred, axis=1)
     return np.sum(py == label) / float(label.size)
 
@@ -50,7 +48,7 @@ def test_mlp():
             softmax,
             X=train_dataiter,
             eval_data=val_dataiter,
-            eval_metric=accuracy,
+            eval_metric=mx.metric.np(accuracy),
             iter_end_callback=mx.callback.do_checkpoint(prefix),
             ctx=[mx.cpu(i) for i in range(2)],
             num_round=num_round,
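Because the numpy-side signature now follows the scikit-learn (y_true, y_pred) convention, existing sklearn metrics can be dropped in with a thin adapter. A sketch under two assumptions not stated in this diff: scikit-learn is installed, and pred carries per-class probabilities (hence the argmax):

import numpy as np
import mxnet as mx
from sklearn.metrics import accuracy_score

def sk_accuracy(label, pred):
    """Adapt sklearn's accuracy_score to the (label, pred) metric interface."""
    # pred rows are class probabilities; reduce them to hard class labels.
    return accuracy_score(label.astype('int32'), np.argmax(pred, axis=1))

eval_metric = mx.metric.np(sk_accuracy)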