diff --git a/maro/rl/storage/column_based_store.py b/maro/rl/storage/column_based_store.py index e41c5f362..8b63015af 100644 --- a/maro/rl/storage/column_based_store.py +++ b/maro/rl/storage/column_based_store.py @@ -182,7 +182,7 @@ def sample_by_keys(self, keys: Sequence, sizes: Sequence, replace: bool = True): return indexes, self.get(indexes) def dumps(self): - return clone(self._store) + return clone(dict(self._store)) def get_by_key(self, key): return self._store[key] diff --git a/tests/test_store.py b/tests/test_store.py index deaa3f581..98684f460 100644 --- a/tests/test_store.py +++ b/tests/test_store.py @@ -33,7 +33,7 @@ def test_update(self): store = ColumnBasedStore(capacity=5, overwrite_type=OverwriteType.ROLLING) store.put({"a": [1, 2, 3, 4, 5], "b": [6, 7, 8, 9, 10], "c": [11, 12, 13, 14, 15]}) store.update([0, 3], {"a": [-1, -4], "c": [-11, -14]}) - actual = store.take() + actual = store.dumps() expected = {"a": [-1, 2, 3, -4, 5], "b": [6, 7, 8, 9, 10], "c": [-11, 12, 13, -14, 15]} self.assertEqual(actual, expected, msg=f"expected store content = {expected}, got {actual}") @@ -54,7 +54,7 @@ def test_put_with_rolling_overwrite(self): indexes = store.put({"a": [10, 11, 12, 13], "b": [14, 15, 16, 17], "c": [18, 19, 20, 21]}) expected = [-2, -1, 0, 1] self.assertEqual(indexes, expected, msg=f"expected indexes = {expected}, got {indexes}") - actual = store.take() + actual = store.dumps() expected = {"a": [12, 13, 3, 10, 11], "b": [16, 17, 6, 14, 15], "c": [20, 21, 9, 18, 19]} self.assertEqual(actual, expected, msg=f"expected store content = {expected}, got {actual}")