diff --git a/docs/containers.rst b/docs/containers.rst index 117c9126..42cd29d8 100644 --- a/docs/containers.rst +++ b/docs/containers.rst @@ -64,6 +64,7 @@ Supporting Classes BranchContainer DepletedMaterial HomogUniv + UnivTuple XSData .. _api-xs: diff --git a/serpentTools/objects/containers.py b/serpentTools/objects/containers.py index f4c7ed92..b64753a8 100644 --- a/serpentTools/objects/containers.py +++ b/serpentTools/objects/containers.py @@ -711,7 +711,7 @@ def __setitem__(self, key, value): raise TypeError("{} {}".format(key, type(value))) if not isinstance(key, UnivTuple): key = UnivTuple(*key) - dict.__setitem__(self, key, value) + super().__setitem__(key, value) def update(self, other): """Update with contents of another BranchContainer""" @@ -722,7 +722,7 @@ def update(self, other): raise TypeError("{} {}".format(key, type(univ))) temp[UnivTuple(*key)] = univ - dict.update(self, temp) + super().update(temp) def getUniv(self, univID, burnup=None, index=None, days=None): """ diff --git a/serpentTools/parsers/branching.py b/serpentTools/parsers/branching.py index e7452437..b970506d 100644 --- a/serpentTools/parsers/branching.py +++ b/serpentTools/parsers/branching.py @@ -5,7 +5,7 @@ from serpentTools.utils import splitValsUncs from serpentTools.objects import BranchContainer, UnivTuple, HomogUniv from serpentTools.parsers.base import XSReader -from serpentTools.messages import debug, error +from serpentTools.messages import debug, error, deprecated class BranchingReader(XSReader): @@ -14,12 +14,12 @@ class BranchingReader(XSReader): Parameters ---------- - filePath: str + filePath : str path to the depletion file Attributes ---------- - branches: dict + branches : dict Dictionary of branch names and their corresponding :class:`~serpentTools.objects.BranchContainer` objects """ @@ -38,6 +38,46 @@ def hasUncs(self): """boolean if uncertainties are present in the file""" return self._hasUncs + def __len__(self): + """Number of branches stored on the reader""" + return len(self.branches) + + def __contains__(self, key): + """Check if a branch is stored on the reader + + Parameters + ---------- + key : str or iterable of str + Name of the branch as defined in the Serpent input file + + Returns + ------- + bool + Flag indicating the presence of ``key`` + + """ + return key in self.branches + + def __getitem__(self, key): + """Return a specific branch from :attr:`branches` + + Parameters + ---------- + key : str or iterable of str + Branch name as defined in Serpent input + + Returns + ------- + serpentTools.objects.BranchContainer + Branch corresponding to ``key`` + + """ + return self.branches[key] + + def __iter__(self): + """Iterate over all branch names""" + return iter(self.branches) + def _read(self): """Read the branching file and store the coefficients.""" with open(self.filePath) as fObj: @@ -131,10 +171,39 @@ def _processBranchUniverses(self, branch, burnup, burnupIndex): else: univ.addData(varName, array(varValues), uncertainty=False) + def get(self, key, default=None): + """Return a branch that may or may not exist in :attr:`branches` + + Parameters + ---------- + key : str or iterable of str + Branch name as defined in Serpent input + default : object, optional + Item to return if ``key`` is not found + + Returns + ------- + object + :class:`~serpentTools.objects.BranchContainer` if + ``key`` is found. ``default`` if not + + """ + return self.branches.get(key, default) + + def items(self): + """Iterate over key, branch pairs from :attr:`branches`""" + return self.branches.items() + + @deprecated("items") def iterBranches(self): - """Iterate over branches yielding paired branch IDs and containers""" - for bID, b in self.branches.items(): - yield bID, b + """Iterate over branches yielding paired branch IDs and containers + + .. deprecated:: 0.9.3 + + Use :meth:`items` instead + + """ + return self.items() def _precheck(self): """Total number of branches and check for uncertainties""" diff --git a/serpentTools/parsers/history.py b/serpentTools/parsers/history.py index 0fed49da..84856da6 100644 --- a/serpentTools/parsers/history.py +++ b/serpentTools/parsers/history.py @@ -83,6 +83,22 @@ def __init__(self, filePath): self.arrays = {} self.numInactive = None + def __getitem__(self, key): + """Return an item from :attr:`arrays`""" + return self.arrays[key] + + def __contains__(self, key): + """Return ``True`` if key is in :attr:`arrays`, otherwise ``False``""" + return key in self.arrays + + def __len__(self): + """Return number of entries in :attr:`arrays`.""" + return len(self.arrays) + + def __iter__(self): + """Iterate over keys in :attr:`arrays`""" + return iter(self.arrays) + def _precheck(self): with open(self.filePath) as check: for line in check: @@ -97,8 +113,9 @@ def _postcheck(self): if self.numInactive is None: error('Unable to acertain the number of inactive cycles') - def __getitem__(self, key): - return self.arrays[key] + def get(self, key, default=None): + """Return an array or default if not found""" + return self.arrays.get(key, default) def _read(self): curKey = None @@ -106,7 +123,7 @@ def _read(self): cycles = None indx = 0 with open(self.filePath) as out: - for lineNo, line in enumerate(out): + for line in out: if not line.strip(): continue if '=' in line: @@ -140,22 +157,9 @@ def _gather_matlab(self, reconvert): out[converter(key)] = value return out - def __contains__(self, key): - """Return ``True`` if key is in :attr:`arrays`, otherwise ``False``""" - return key in self.arrays - - def __len__(self): - """Return number of entries in :attr:`arrays`.""" - return len(self.arrays) - def items(self): """Iterate over ``(key, value)`` pairs from :attr:`arrays`""" - for key, value in self.arrays.items(): - yield key, value - - def __iter__(self): - """Iterate over keys in :attr:`arrays`""" - return self.arrays.__iter__() + return self.arrays.items() @staticmethod def ioConvertName(name): diff --git a/tests/test_branching.py b/tests/test_branching.py index 22ea4aee..1aa6932c 100644 --- a/tests/test_branching.py +++ b/tests/test_branching.py @@ -63,11 +63,31 @@ def test_raiseError(self): def test_branchingUniverses(self): """Verify that the correct universes are present.""" - for branchID, branch in self.reader.iterBranches(): + for branchID, branch in self.reader.items(): self.assertSetEqual( self.expectedUniverses, set(branch), 'Branch {}'.format(branchID)) + def test_special(self): + """Test special methods like len, get""" + self.assertEqual(len(self.reader), len(self.reader.branches)) + allBranches = set(self.reader.branches) + self.assertSetEqual(set(self.reader), allBranches) + + for key, branch in self.reader.items(): + self.assertIn(key, self.reader, msg=key) + self.assertIs(self.reader[key], branch, msg=key) + self.assertIs(self.reader.get(key), branch, msg=key) + allBranches.remove(key) + + self.assertSetEqual( + allBranches, set(), msg="Did not iterate over all items") + + with self.assertRaises(KeyError): + self.reader["this should not exist"] + + self.assertIs(self.reader.get("this should not exist"), None) + class BranchContainerTester(_BranchTesterHelper): """Class to test the branch container""" diff --git a/tests/test_history.py b/tests/test_history.py index 12c3cd6b..75117636 100644 --- a/tests/test_history.py +++ b/tests/test_history.py @@ -264,11 +264,6 @@ def test_sizes(self): self.assertTupleEqual(shape, self.arrays[key].shape, msg=key) - def test_getItem(self): - """Verify the getitem indexing is functional.""" - for key, readerArray in self.arrays.items(): - self.assertIs(readerArray, self.reader[key], msg=key) - def test_arrayHeads(self): """Verify the first few lines of each array are correct.""" for key, expectedArray in EXPECTED_ARRAY_HEADS.items(): @@ -285,11 +280,16 @@ def test_arrayTails(self): def test_specialMethods(self): """Test special methods on the reader""" + for key, readerArray in self.arrays.items(): + self.assertIs(self.reader[key], readerArray, msg=key) # test len self.assertEqual(len(self.reader), len(self.reader.arrays)) # test contains badKey = 'this_shouldNotBe_present' - self.assertFalse(badKey in self.reader) + self.assertNotIn(badKey, self.reader) + with self.assertRaises(KeyError): + self.reader[badKey] + self.assertIs(self.reader.get(badKey), None) # test iter for nFound, key in enumerate(self.reader, start=1): self.assertTrue(key in self.reader.arrays, msg=key) @@ -299,10 +299,11 @@ def test_specialMethods(self): def test_iterItems(self): """Test the items method for yielding key, value pairs""" for nFound, (key, value) in enumerate(self.reader.items(), start=1): - self.assertTrue(key in self.reader.arrays, msg=key) - self.assertTrue(key in self.reader, msg=key) - assert_array_equal(value, self.reader.arrays[key], err_msg=key) - self.assertTrue(value is self.reader[key], msg=key) + self.assertIn(key, self.reader.arrays, msg=key) + self.assertIn(key, self.reader, msg=key) + self.assertIs(value, self.reader.arrays[key], msg=key) + self.assertIs(value, self.reader[key], msg=key) + self.assertIs(value, self.reader.get(key), msg=key) self.assertEqual(nFound, len(self.reader.arrays))