Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Check CIFTI-2 data shape matches shape described by header #774

Merged
merged 9 commits into from
Jul 13, 2019
Merged
2 changes: 1 addition & 1 deletion nibabel/batteryrunners.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def check_only(self, obj):
-------
reports : sequence
sequence of report objects reporting on result of running
checks (withou fixes) on `obj`
checks (without fixes) on `obj`
'''
reports = []
for check in self._checks:
Expand Down
49 changes: 46 additions & 3 deletions nibabel/cifti2/cifti2.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from ..nifti2 import Nifti2Image, Nifti2Header
from ..arrayproxy import reshape_dataobj
from ..keywordonly import kw_only_meth
from warnings import warn


def _float_01(val):
Expand Down Expand Up @@ -1209,6 +1210,38 @@ def _to_xml_element(self):
mat.append(mim._to_xml_element())
return mat

def get_axis(self, index):
'''
Generates the Cifti2 axis for a given dimension

Parameters
----------
index : int
Dimension for which we want to obtain the mapping.

Returns
-------
axis : :class:`.cifti2_axes.Axis`
'''
from . import cifti2_axes
return cifti2_axes.from_index_mapping(self.get_index_map(index))

def get_data_shape(self):
"""
Returns data shape expected based on the CIFTI-2 header

Any dimensions omitted in the CIFTI-2 header will be given a default size of None.
"""
from . import cifti2_axes
if len(self.mapped_indices) == 0:
return ()
base_shape = [None] * (max(self.mapped_indices) + 1)
for mim in self:
size = len(cifti2_axes.from_index_mapping(mim))
for idx in mim.applies_to_matrix_dimension:
base_shape[idx] = size
return tuple(base_shape)


class Cifti2Header(FileBasedHeader, xml.XmlSerializable):
''' Class for CIFTI-2 header extension '''
Expand Down Expand Up @@ -1279,8 +1312,7 @@ def get_axis(self, index):
-------
axis : :class:`.cifti2_axes.Axis`
'''
from . import cifti2_axes
return cifti2_axes.from_index_mapping(self.matrix.get_index_map(index))
return self.matrix.get_axis(index)

@classmethod
def from_axes(cls, axes):
Expand Down Expand Up @@ -1345,12 +1377,18 @@ def __init__(self,
super(Cifti2Image, self).__init__(dataobj, header=header,
extra=extra, file_map=file_map)
self._nifti_header = Nifti2Header.from_header(nifti_header)

# if NIfTI header not specified, get data type from input array
if nifti_header is None:
if hasattr(dataobj, 'dtype'):
self._nifti_header.set_data_dtype(dataobj.dtype)
self.update_headers()

if self._dataobj.shape != self.header.matrix.get_data_shape():
warn("Dataobj shape {} does not match shape expected from CIFTI-2 header {}".format(
self._dataobj.shape, self.header.matrix.get_data_shape()
))

@property
def nifti_header(self):
return self._nifti_header
Expand Down Expand Up @@ -1426,6 +1464,11 @@ def to_file_map(self, file_map=None):
header = self._nifti_header
extension = Cifti2Extension(content=self.header.to_xml())
header.extensions.append(extension)
if self._dataobj.shape != self.header.matrix.get_data_shape():
raise ValueError(
"Dataobj shape {} does not match shape expected from CIFTI-2 header {}".format(
self._dataobj.shape, self.header.matrix.get_data_shape()
))
# if intent code is not set, default to unknown CIFTI
if header.get_intent()[0] == 'none':
header.set_intent('NIFTI_INTENT_CONNECTIVITY_UNKNOWN')
Expand All @@ -1438,7 +1481,7 @@ def to_file_map(self, file_map=None):
img.to_file_map(file_map or self.file_map)

def update_headers(self):
''' Harmonize CIFTI-2 and NIfTI headers with image data
''' Harmonize NIfTI headers with image data

>>> import numpy as np
>>> data = np.zeros((2,3,4))
Expand Down
6 changes: 6 additions & 0 deletions nibabel/cifti2/tests/test_cifti2.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,4 +358,10 @@ class TestCifti2ImageAPI(_TDA):
standard_extension = '.nii'

def make_imaker(self, arr, header=None, ni_header=None):
for idx, sz in enumerate(arr.shape):
maps = [ci.Cifti2NamedMap(str(value)) for value in range(sz)]
mim = ci.Cifti2MatrixIndicesMap(
(idx, ), 'CIFTI_INDEX_TYPE_SCALARS', maps=maps
)
header.matrix.append(mim)
return lambda: self.image_maker(arr.copy(), header, ni_header)
56 changes: 42 additions & 14 deletions nibabel/cifti2/tests/test_new_cifti2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
from nibabel import cifti2 as ci
from nibabel.tmpdirs import InTemporaryDirectory

from nose.tools import assert_true, assert_equal
from nose.tools import assert_true, assert_equal, assert_raises
from nibabel.testing import clear_and_catch_warnings, error_warnings, suppress_warnings

affine = [[-1.5, 0, 0, 90],
[0, 1.5, 0, -85],
[0, 0, 1.5, -71]]
[0, 0, 1.5, -71],
[0, 0, 0, 1.]]

dimensions = (120, 83, 78)

Expand Down Expand Up @@ -234,7 +236,7 @@ def test_dtseries():
matrix.append(series_map)
matrix.append(geometry_map)
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(13, 9)
data = np.random.randn(13, 10)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_DENSE_SERIES')

Expand All @@ -257,7 +259,7 @@ def test_dscalar():
matrix.append(scalar_map)
matrix.append(geometry_map)
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(2, 9)
data = np.random.randn(2, 10)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_DENSE_SCALARS')

Expand All @@ -279,7 +281,7 @@ def test_dlabel():
matrix.append(label_map)
matrix.append(geometry_map)
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(2, 9)
data = np.random.randn(2, 10)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_DENSE_LABELS')

Expand All @@ -299,7 +301,7 @@ def test_dconn():
matrix = ci.Cifti2Matrix()
matrix.append(mapping)
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(9, 9)
data = np.random.randn(10, 10)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_DENSE')

Expand All @@ -322,7 +324,7 @@ def test_ptseries():
matrix.append(series_map)
matrix.append(parcel_map)
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(13, 3)
data = np.random.randn(13, 4)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_SERIES')

Expand All @@ -344,7 +346,7 @@ def test_pscalar():
matrix.append(scalar_map)
matrix.append(parcel_map)
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(2, 3)
data = np.random.randn(2, 4)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_SCALAR')

Expand All @@ -366,7 +368,7 @@ def test_pdconn():
matrix.append(geometry_map)
matrix.append(parcel_map)
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(2, 3)
data = np.random.randn(10, 4)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_DENSE')

Expand All @@ -388,7 +390,7 @@ def test_dpconn():
matrix.append(parcel_map)
matrix.append(geometry_map)
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(2, 3)
data = np.random.randn(4, 10)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_DENSE_PARCELLATED')

Expand All @@ -410,7 +412,7 @@ def test_plabel():
matrix.append(label_map)
matrix.append(parcel_map)
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(2, 3)
data = np.random.randn(2, 4)
img = ci.Cifti2Image(data, hdr)

with InTemporaryDirectory():
Expand All @@ -429,7 +431,7 @@ def test_pconn():
matrix = ci.Cifti2Matrix()
matrix.append(mapping)
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(3, 3)
data = np.random.randn(4, 4)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED')

Expand All @@ -453,7 +455,7 @@ def test_pconnseries():
matrix.append(parcel_map)
matrix.append(series_map)
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(3, 3, 13)
data = np.random.randn(4, 4, 13)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_'
'PARCELLATED_SERIES')
Expand All @@ -479,7 +481,7 @@ def test_pconnscalar():
matrix.append(parcel_map)
matrix.append(scalar_map)
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(3, 3, 13)
data = np.random.randn(4, 4, 2)
img = ci.Cifti2Image(data, hdr)
img.nifti_header.set_intent('NIFTI_INTENT_CONNECTIVITY_PARCELLATED_'
'PARCELLATED_SCALAR')
Expand All @@ -496,3 +498,29 @@ def test_pconnscalar():
check_parcel_map(img2.header.matrix.get_index_map(0))
check_scalar_map(img2.header.matrix.get_index_map(2))
del img2


def test_wrong_shape():
scalar_map = create_scalar_map((0, ))
brain_model_map = create_geometry_map((1, ))

matrix = ci.Cifti2Matrix()
matrix.append(scalar_map)
matrix.append(brain_model_map)
hdr = ci.Cifti2Header(matrix)

# correct shape is (2, 10)
for data in (
np.random.randn(1, 11),
np.random.randn(2, 10, 1),
np.random.randn(1, 2, 10),
np.random.randn(3, 10),
np.random.randn(2, 9),
):
with clear_and_catch_warnings():
with error_warnings():
assert_raises(UserWarning, ci.Cifti2Image, data, hdr)
with suppress_warnings():
img = ci.Cifti2Image(data, hdr)
assert_raises(ValueError, img.to_file_map)