From 92cc356c7f17ac031f6cd0a8330145d2b28d41d1 Mon Sep 17 00:00:00 2001 From: Nadia Dencheva Date: Tue, 20 Sep 2016 17:41:31 -0400 Subject: [PATCH] Add Tabular Model. add another test and fix comparison Add a changelog entry [skip ci] fix a changelog entry [skip ci] --- CHANGES.rst | 7 ++- asdf/tags/transform/__init__.py | 1 + asdf/tags/transform/tabular.py | 69 +++++++++++++++++++++ asdf/tags/transform/tests/test_transform.py | 19 ++++++ 4 files changed, 95 insertions(+), 1 deletion(-) create mode 100644 asdf/tags/transform/tabular.py diff --git a/CHANGES.rst b/CHANGES.rst index ec395adbe..9a7e5ee3a 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,7 +1,12 @@ +1.1.1(Unreleased) +----------------- + +- Added Tabular model. [#214] + 1.0.5 (2016-06-28) ------------------ -- Fixed a memory leak when reading wcs that grew memory to over 10 Gb +- Fixed a memory leak when reading wcs that grew memory to over 10 Gb. [#200] 1.0.4 (2016-05-25) ------------------ diff --git a/asdf/tags/transform/__init__.py b/asdf/tags/transform/__init__.py index 61fc9f56d..891325435 100644 --- a/asdf/tags/transform/__init__.py +++ b/asdf/tags/transform/__init__.py @@ -7,3 +7,4 @@ from .compound import * from .projections import * from .polynomial import * +from .tabular import * diff --git a/asdf/tags/transform/tabular.py b/asdf/tags/transform/tabular.py new file mode 100644 index 000000000..4cba77217 --- /dev/null +++ b/asdf/tags/transform/tabular.py @@ -0,0 +1,69 @@ +# Licensed under a 3-clause BSD style license - see LICENSE.rst +# -*- coding: utf-8 -*- + +from __future__ import absolute_import, division, unicode_literals, print_function + +import numpy as np +from numpy.testing import assert_array_equal +from ... import yamlutil + +from .basic import TransformType + +__all__ = ['TabularType'] + + +class TabularType(TransformType): + import astropy + name = "transform/tabular" + types = [astropy.modeling.models.Tabular2D, + astropy.modeling.models.Tabular1D + ] + + @classmethod + def from_tree_transform(cls, node, ctx): + from astropy import modeling + lookup_table = node.pop("lookup_table") + dim = lookup_table.ndim + name = node.get('name', None) + fill_value = node.pop("fill_value", None) + points = np.asarray(node['points']) + if dim == 1: + model = modeling.models.Tabular1D(points=points, lookup_table=lookup_table, + method=node['method'], bounds_error=node['bounds_error'], + fill_value=fill_value, name=name) + elif dim == 2: + model = modeling.models.Tabular2D(points=points, lookup_table=lookup_table, + method=node['method'], bounds_error=node['bounds_error'], + fill_value=fill_value, name=name) + else: + tabular_class = modeling.models.tabular_model(dim, name) + + model = tabular_class(points=points, lookup_table=lookup_table, + method=node['method'], bounds_error=node['bounds_error'], + fill_value=fill_value, name=name) + + return model + + @classmethod + def to_tree_transform(cls, model, ctx): + node = {} + node["fill_value"] = model.fill_value + node["lookup_table"] = model.lookup_table + node["points"] = model.points + node["method"] = str(model.method) + node["bounds_error"] = model.bounds_error + node["name"] = model.name + return yamlutil.custom_tree_to_tagged_tree(node, ctx) + + @classmethod + def assert_equal(cls, a, b): + assert_array_equal(a.lookup_table, b.lookup_table) + assert_array_equal(a.points, b.points) + assert (a.method == b.method) + if a.fill_value is None: + assert b.fill_value is None + elif np.isnan(a.fill_value): + assert np.isnan(b.fill_value) + else: + assert(a.fill_value == b.fill_value) + assert(a.bounds_error == b.bounds_error) diff --git a/asdf/tags/transform/tests/test_transform.py b/asdf/tags/transform/tests/test_transform.py index 7b8a5870a..da7ea580e 100644 --- a/asdf/tags/transform/tests/test_transform.py +++ b/asdf/tags/transform/tests/test_transform.py @@ -3,6 +3,8 @@ from __future__ import absolute_import, division, unicode_literals, print_function +import numpy as np + try: import astropy except ImportError: @@ -126,3 +128,20 @@ def test_generic_projections(tmpdir): } helpers.assert_roundtrip_tree(tree, tmpdir) + + +@pytest.mark.skipif('not HAS_ASTROPY') +def test_tabular_model(tmpdir): + points = np.arange(0, 5) + values = [1., 10, 2, 45, -3] + model = astmodels.Tabular1D(points=points, lookup_table=values) + tree = {'model': model} + helpers.assert_roundtrip_tree(tree, tmpdir) + table = np.array([[ 3., 0., 0.], + [ 0., 2., 0.], + [ 0., 0., 0.]]) + points = ([1, 2, 3], [1, 2, 3]) + model2 = astmodels.Tabular2D(points, lookup_table=table, bounds_error=False, + fill_value=None, method='nearest') + tree = {'model': model2} + helpers.assert_roundtrip_tree(tree, tmpdir)