Skip to content

Commit

Permalink
Add Tabular Model.
Browse files Browse the repository at this point in the history
add another test and fix comparison

Add a changelog entry [skip ci]

fix a changelog entry [skip ci]
  • Loading branch information
nden committed Sep 21, 2016
1 parent 50546cf commit 92cc356
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 1 deletion.
7 changes: 6 additions & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
1.1.1(Unreleased)
-----------------

- Added Tabular model. [#214]

1.0.5 (2016-06-28)
------------------

- Fixed a memory leak when reading wcs that grew memory to over 10 Gb
- Fixed a memory leak when reading wcs that grew memory to over 10 Gb. [#200]

1.0.4 (2016-05-25)
------------------
Expand Down
1 change: 1 addition & 0 deletions asdf/tags/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
from .compound import *
from .projections import *
from .polynomial import *
from .tabular import *
69 changes: 69 additions & 0 deletions asdf/tags/transform/tabular.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst
# -*- coding: utf-8 -*-

from __future__ import absolute_import, division, unicode_literals, print_function

import numpy as np
from numpy.testing import assert_array_equal
from ... import yamlutil

from .basic import TransformType

__all__ = ['TabularType']


class TabularType(TransformType):
import astropy
name = "transform/tabular"
types = [astropy.modeling.models.Tabular2D,
astropy.modeling.models.Tabular1D
]

@classmethod
def from_tree_transform(cls, node, ctx):
from astropy import modeling
lookup_table = node.pop("lookup_table")
dim = lookup_table.ndim
name = node.get('name', None)
fill_value = node.pop("fill_value", None)
points = np.asarray(node['points'])
if dim == 1:
model = modeling.models.Tabular1D(points=points, lookup_table=lookup_table,
method=node['method'], bounds_error=node['bounds_error'],
fill_value=fill_value, name=name)
elif dim == 2:
model = modeling.models.Tabular2D(points=points, lookup_table=lookup_table,
method=node['method'], bounds_error=node['bounds_error'],
fill_value=fill_value, name=name)
else:
tabular_class = modeling.models.tabular_model(dim, name)

model = tabular_class(points=points, lookup_table=lookup_table,
method=node['method'], bounds_error=node['bounds_error'],
fill_value=fill_value, name=name)

return model

@classmethod
def to_tree_transform(cls, model, ctx):
node = {}
node["fill_value"] = model.fill_value
node["lookup_table"] = model.lookup_table
node["points"] = model.points
node["method"] = str(model.method)
node["bounds_error"] = model.bounds_error
node["name"] = model.name
return yamlutil.custom_tree_to_tagged_tree(node, ctx)

@classmethod
def assert_equal(cls, a, b):
assert_array_equal(a.lookup_table, b.lookup_table)
assert_array_equal(a.points, b.points)
assert (a.method == b.method)
if a.fill_value is None:
assert b.fill_value is None
elif np.isnan(a.fill_value):
assert np.isnan(b.fill_value)
else:
assert(a.fill_value == b.fill_value)
assert(a.bounds_error == b.bounds_error)
19 changes: 19 additions & 0 deletions asdf/tags/transform/tests/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from __future__ import absolute_import, division, unicode_literals, print_function

import numpy as np

try:
import astropy
except ImportError:
Expand Down Expand Up @@ -126,3 +128,20 @@ def test_generic_projections(tmpdir):
}

helpers.assert_roundtrip_tree(tree, tmpdir)


@pytest.mark.skipif('not HAS_ASTROPY')
def test_tabular_model(tmpdir):
points = np.arange(0, 5)
values = [1., 10, 2, 45, -3]
model = astmodels.Tabular1D(points=points, lookup_table=values)
tree = {'model': model}
helpers.assert_roundtrip_tree(tree, tmpdir)
table = np.array([[ 3., 0., 0.],
[ 0., 2., 0.],
[ 0., 0., 0.]])
points = ([1, 2, 3], [1, 2, 3])
model2 = astmodels.Tabular2D(points, lookup_table=table, bounds_error=False,
fill_value=None, method='nearest')
tree = {'model': model2}
helpers.assert_roundtrip_tree(tree, tmpdir)

0 comments on commit 92cc356

Please sign in to comment.