From 0b13f87c9a07421c2da39609b29743b46daaf6bb Mon Sep 17 00:00:00 2001
From: Quentin Lhoest
Date: Thu, 19 Oct 2023 14:18:43 +0200
Subject: [PATCH 1/2] deterministic set hash

---
 src/datasets/utils/py_utils.py | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/src/datasets/utils/py_utils.py b/src/datasets/utils/py_utils.py
index 912824fa7f3..243bc0f99c9 100644
--- a/src/datasets/utils/py_utils.py
+++ b/src/datasets/utils/py_utils.py
@@ -736,6 +736,29 @@ def proxy(func):
         return proxy
 
 
+if config.DILL_VERSION < version.parse("0.3.6"):
+
+    @pklregister(set)
+    def _save_set(pickler, obj):
+        dill._dill.log.info(f"Se: {obj}")
+        from datasets.fingerprint import Hasher
+
+        args = (sorted(obj, key=Hasher.hash),)
+        pickler.save_reduce(set, args, obj=obj)
+        dill._dill.log.info("# Se")
+
+elif config.DILL_VERSION.release[:3] in [version.parse("0.3.6").release, version.parse("0.3.7").release]:
+
+    @pklregister(set)
+    def _save_set(pickler, obj):
+        dill._dill.logger.trace(pickler, "Se: %s", obj)
+        from datasets.fingerprint import Hasher
+
+        args = (sorted(obj, key=Hasher.hash),)
+        pickler.save_reduce(set, args, obj=obj)
+        dill._dill.logger.trace(pickler, "# Se")
+
+
 if config.DILL_VERSION < version.parse("0.3.6"):
 
     @pklregister(CodeType)

From 7f1a7d621fff3b08ace02643466097654a5e010f Mon Sep 17 00:00:00 2001
From: Quentin Lhoest
Date: Thu, 19 Oct 2023 14:18:46 +0200
Subject: [PATCH 2/2] tests

---
 tests/test_fingerprint.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/tests/test_fingerprint.py b/tests/test_fingerprint.py
index 5538f27554e..f4d5d65744e 100644
--- a/tests/test_fingerprint.py
+++ b/tests/test_fingerprint.py
@@ -2,6 +2,7 @@
 import os
 import pickle
 import subprocess
+from functools import partial
 from hashlib import md5
 from pathlib import Path
 from tempfile import gettempdir
@@ -10,6 +11,7 @@
 from unittest import TestCase
 from unittest.mock import patch
 
+import numpy as np
 import pytest
 
 from multiprocess import Pool
@@ -254,6 +256,22 @@ def test_hash_same_strings(self):
         self.assertEqual(hash1, hash2)
         self.assertEqual(hash1, hash3)
 
+    def test_set_stable(self):
+        rng = np.random.default_rng(42)
+        set_ = {rng.random() for _ in range(10_000)}
+        expected_hash = Hasher.hash(set_)
+        assert expected_hash == Pool(1).apply_async(partial(Hasher.hash, set(set_))).get()
+
+    def test_set_doesnt_depend_on_order(self):
+        set_ = set("abc")
+        hash1 = md5(datasets.utils.py_utils.dumps(set_)).hexdigest()
+        set_ = set("def")
+        hash2 = md5(datasets.utils.py_utils.dumps(set_)).hexdigest()
+        set_ = set("cba")
+        hash3 = md5(datasets.utils.py_utils.dumps(set_)).hexdigest()
+        self.assertEqual(hash1, hash3)
+        self.assertNotEqual(hash1, hash2)
+
     @require_tiktoken
     def test_hash_tiktoken_encoding(self):
         import tiktoken
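
A minimal standalone sketch of the idea the patch applies, separate from the patch itself: set iteration order is not guaranteed to be the same across processes, so the elements are sorted by a content-based key before serialization. The helpers stable_hash and dumps_set_deterministically below are hypothetical stand-ins for datasets' Hasher.hash and datasets.utils.py_utils.dumps, not the library's actual API.

    import pickle
    from hashlib import md5


    def stable_hash(value) -> str:
        # Hash the pickled bytes; unlike built-in hash(), this is not salted
        # per process, so the ordering key is reproducible everywhere.
        return md5(pickle.dumps(value)).hexdigest()


    def dumps_set_deterministically(s: set) -> bytes:
        # Sorting by the content-based key removes any dependence on the
        # set's iteration order before the elements are serialized.
        return pickle.dumps(sorted(s, key=stable_hash))


    # Equal sets built in different orders serialize to identical bytes.
    assert dumps_set_deterministically(set("abc")) == dumps_set_deterministically(set("cba"))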