diff --git a/qcodes/data/gnuplot_format.py b/qcodes/data/gnuplot_format.py index decaa7d3574..28052fd797c 100644 --- a/qcodes/data/gnuplot_format.py +++ b/qcodes/data/gnuplot_format.py @@ -3,7 +3,7 @@ import math import json -from qcodes.utils.helpers import deep_update +from qcodes.utils.helpers import deep_update, NumpyJSONEncoder from .data_array import DataArray from .format import Formatter @@ -297,7 +297,7 @@ def write_metadata(self, data_set, read_first=True): fn = io_manager.join(location, self.metadata_file) with io_manager.open(fn, 'w', encoding='utf8') as snap_file: json.dump(data_set.metadata, snap_file, sort_keys=True, - indent=4, ensure_ascii=False) + indent=4, ensure_ascii=False, cls=NumpyJSONEncoder) def read_metadata(self, data_set): io_manager = data_set.io diff --git a/qcodes/tests/test_json.py b/qcodes/tests/test_json.py new file mode 100644 index 00000000000..0c81e96ebc4 --- /dev/null +++ b/qcodes/tests/test_json.py @@ -0,0 +1,44 @@ +from unittest import TestCase +import numpy as np +import json + +from qcodes.utils.helpers import NumpyJSONEncoder + + +class TestNumpyJson(TestCase): + + def setUp(self): + self.metadata = { + 'name': 'Rapunzel', + 'age': np.int64(12), + 'height': np.float64(112.234), + 'scores': np.linspace(0, 42, num=3), + # include some regular values to ensure they work right + # with our encoder + 'weight': 19, + 'length': 45.23, + 'points': [12, 24, 48] + } + + def test_numpy_fail(self): + metadata = self.metadata + with self.assertRaises(TypeError): + json.dumps(metadata, sort_keys=True, indent=4, ensure_ascii=False) + + def test_numpy_good(self): + metadata = self.metadata + data = json.dumps(metadata, sort_keys=True, indent=4, + ensure_ascii=False, cls=NumpyJSONEncoder) + data_dict = json.loads(data) + + metadata = { + 'name': 'Rapunzel', + 'age': 12, + 'height': 112.234, + 'scores': [0, 21, 42], + 'weight': 19, + 'length': 45.23, + 'points': [12, 24, 48] + } + + self.assertEqual(metadata, data_dict) diff --git a/qcodes/utils/helpers.py b/qcodes/utils/helpers.py index b00bc368a99..343f2eec63e 100644 --- a/qcodes/utils/helpers.py +++ b/qcodes/utils/helpers.py @@ -6,10 +6,25 @@ import sys import io import numpy as np +import json _tprint_times = {} +class NumpyJSONEncoder(json.JSONEncoder): + """Return numpy types as standard types.""" + # http://stackoverflow.com/questions/27050108/convert-numpy-type-to-python + # http://stackoverflow.com/questions/9452775/converting-numpy-dtypes-to-native-python-types/11389998#11389998 + def default(self, obj): + if isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + else: + return super(NumpyJSONEncoder, self).default(obj) + def tprint(string, dt=1, tag='default'): """ Print progress of a loop every dt seconds """ ptime = _tprint_times.get(tag, 0)