From df582d235a6e6c8e114053015a7b7392bee8f570 Mon Sep 17 00:00:00 2001 From: Chris Holden Date: Sun, 6 Dec 2015 17:03:02 -0500 Subject: [PATCH] Add regression diagnotic calcs #70 --- tests/regression/conftest.py | 6 +++++ .../regression/test_regression_diagnostics.py | 18 +++++++++++++++ yatsm/regression/diagnostics.py | 22 +++++++++++++++++++ 3 files changed, 46 insertions(+) create mode 100644 tests/regression/test_regression_diagnostics.py create mode 100644 yatsm/regression/diagnostics.py diff --git a/tests/regression/conftest.py b/tests/regression/conftest.py index b33362ab..e4de9069 100644 --- a/tests/regression/conftest.py +++ b/tests/regression/conftest.py @@ -1,11 +1,17 @@ import os +import numpy as np import pandas as pd import pytest here = os.path.dirname(__file__) +@pytest.fixture(scope='function') +def prng(): + return np.random.RandomState(123456789) + + @pytest.fixture(scope='function') def airquality(request): airquality = pd.read_csv(os.path.join(here, 'data', 'airquality.csv')) diff --git a/tests/regression/test_regression_diagnostics.py b/tests/regression/test_regression_diagnostics.py new file mode 100644 index 00000000..efe5c3f2 --- /dev/null +++ b/tests/regression/test_regression_diagnostics.py @@ -0,0 +1,18 @@ +import numpy as np +import pytest + +from yatsm.regression.diagnostics import rmse + +n = 500 + + +@pytest.fixture +def y(prng): + y = prng.standard_normal(n) + return y + + +def test_rmse(y): + yhat = y + 1.0 + _rmse = rmse(y, yhat) + np.testing.assert_allclose(_rmse, 1.0) diff --git a/yatsm/regression/diagnostics.py b/yatsm/regression/diagnostics.py new file mode 100644 index 00000000..85499ce3 --- /dev/null +++ b/yatsm/regression/diagnostics.py @@ -0,0 +1,22 @@ +""" Regression diagnostics calculations + +Includes: + - rmse: calculate root mean squared error +""" +import numpy as np + +from ..accel import try_jit + + +@try_jit(nopython=True) +def rmse(y, yhat): + """ Calculate and return Root Mean Squared Error (RMSE) + + Args: + y (np.ndarray): known values + yhat (np.ndarray): predicted values + + Returns: + float: Root Mean Squared Error + """ + return ((y - yhat) ** 2).mean() ** 0.5