From 5ea1f33f388efb64e8bc413f85a4d7f552ad09c9 Mon Sep 17 00:00:00 2001 From: Scott Date: Thu, 15 Jul 2021 12:19:12 -0500 Subject: [PATCH 01/51] API: allow specification of triplets in validation --- salmon/triplets/samplers/_validation.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/salmon/triplets/samplers/_validation.py b/salmon/triplets/samplers/_validation.py index 8793f83a..c9e5a905 100644 --- a/salmon/triplets/samplers/_validation.py +++ b/salmon/triplets/samplers/_validation.py @@ -11,7 +11,7 @@ class Validation(RoundRobin): """Ask about the same queries repeatedly""" - def __init__(self, n, d=2, n_queries=20, ident=""): + def __init__(self, n, d=2, n_queries=20, queries=None, ident=""): """ This sampler asks the same questions repeatedly, useful to evaluate query difficulty. @@ -26,10 +26,23 @@ def __init__(self, n, d=2, n_queries=20, ident=""): Number of validation queries. d : int Embedding dimension. + queries : List[Tuple[int, int, int]] + The list of queries to ask about. Each query is + ``(head, obj1, obj2)`` where ``obj1`` and ``obj2`` are + randomly shown on the left and right. Each item in the tuple + is the `index` of the target to ask about. For example: + + .. code-block:: python + + queries=[(0, 1, 2), (3, 4, 5), (6, 7, 8)] + + will first ask about a query with ``head_index=0``, then + ``head_index=3``, then ``head_index=6``. """ self.n_queries = n_queries - Q = [np.random.choice(n, size=3, replace=False) for _ in range(n_queries)] - self._val_queries = Q + if queries is not None: + queries = [np.random.choice(n, size=3, replace=False) for _ in range(n_queries)] + self._val_queries = queries super().__init__(n=n, d=d, ident=ident) def get_query(self): From f9ecaa083f239ba57e3f92651a2e4cbdd0ba2211 Mon Sep 17 00:00:00 2001 From: Scott Date: Thu, 15 Jul 2021 16:10:57 -0500 Subject: [PATCH 02/51] add example and debug --- examples/basic.yaml | 4 +++- salmon/frontend/public.py | 7 ++++--- salmon/triplets/samplers/_validation.py | 2 +- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/examples/basic.yaml b/examples/basic.yaml index 8668a4b4..bbd42329 100644 --- a/examples/basic.yaml +++ b/examples/basic.yaml @@ -4,5 +4,7 @@ max_queries: 100 samplers: ARR: {} RandomSampling: {} + Validation: + queries: [[0, 1, 2], [3, 4, 5], [6, 7, 8]] sampling: - probs: {"ARR": 80, "RandomSampling": 20} + probs: {"ARR": 40, "Validation": 40, "RandomSampling": 20} diff --git a/salmon/frontend/public.py b/salmon/frontend/public.py index 0e5f4c2a..903ce152 100644 --- a/salmon/frontend/public.py +++ b/salmon/frontend/public.py @@ -100,13 +100,14 @@ async def _ensure_initialized(): @app.get("/", tags=["public"]) -async def get_query_page(request: Request): +async def get_query_page(request: Request, puid: str=""): """ Load the query page and present a "triplet query". """ exp_config = await _ensure_initialized() - uid = "salmon-{}".format(np.random.randint(2 ** 32 - 1)) - puid = sha256(uid)[:16] + if puid == "": + uid = "salmon-{}".format(np.random.randint(2 ** 32 - 1)) + puid = sha256(uid)[:16] items = { "puid": puid, "instructions": exp_config["instructions"], diff --git a/salmon/triplets/samplers/_validation.py b/salmon/triplets/samplers/_validation.py index c9e5a905..c1696e3a 100644 --- a/salmon/triplets/samplers/_validation.py +++ b/salmon/triplets/samplers/_validation.py @@ -40,7 +40,7 @@ def __init__(self, n, d=2, n_queries=20, queries=None, ident=""): ``head_index=3``, then ``head_index=6``. """ self.n_queries = n_queries - if queries is not None: + if queries is None: queries = [np.random.choice(n, size=3, replace=False) for _ in range(n_queries)] self._val_queries = queries super().__init__(n=n, d=d, ident=ident) From d44721fdd16f443b3527a611d34f8cbfc41d3ced Mon Sep 17 00:00:00 2001 From: Scott Date: Thu, 15 Jul 2021 16:13:45 -0500 Subject: [PATCH 03/51] only run doc build on release --- .github/workflows/docs.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 2b5bad39..d17713cc 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -1,9 +1,9 @@ name: Documentation build -on: push -# on: - # release: - # types: [published] +# on: push +on: + release: + types: [published] # Only run when release published (not created or edited, etc) # https://docs.github.com/en/actions/reference/events-that-trigger-workflows#release From 97719c7a52000a767e0771b8fb35a80c8bb5fc58 Mon Sep 17 00:00:00 2001 From: Scott Date: Mon, 19 Jul 2021 14:22:58 -0500 Subject: [PATCH 04/51] add samplers_per_user --- docs/source/getting-started.rst | 11 +++++++++-- salmon/frontend/private.py | 15 ++++++++++++++- salmon/frontend/public.py | 12 +++++++----- templates/query_page.html | 14 +++++++++++++- tests/test_active.py | 27 +++++++++++++++++++++++++++ 5 files changed, 70 insertions(+), 9 deletions(-) diff --git a/docs/source/getting-started.rst b/docs/source/getting-started.rst index 4cecf494..ba3cd653 100644 --- a/docs/source/getting-started.rst +++ b/docs/source/getting-started.rst @@ -103,8 +103,15 @@ in YAML jargon. Here's documentation for each key: * ``max_queries``: int. The number of queries a participant should answer. Set ``max_queries: -1`` for unlimited queries. * ``samplers``. See :ref:`adaptive-config` for more detail. -* ``sampling``. A dictionary with the key ``probs`` and percentage - probabilities for each algorithm. +* ``sampling``. A dictionary with the following keys: + + * ``probs``, a map between sampler names and the percentage that + each sampler is selected. + + * ``samplers_per_user``: (optional int, default=0). Controls the + number of samplers each user sees. If ``samplers_per_user=0``, show + users a random sampler. + * ``targets``, optional list. Choices: * YAML list. This ``targets: ["vonn", "miller", "ligety", "shiffrin"]`` is diff --git a/salmon/frontend/private.py b/salmon/frontend/private.py index 3f293b99..8f2c4cda 100644 --- a/salmon/frontend/private.py +++ b/salmon/frontend/private.py @@ -293,11 +293,24 @@ async def _get_config(exp: bytes, targets: bytes) -> Dict[str, Any]: } exp_config.update(config) if "sampling" not in exp_config: + exp_config["sampling"] = {} + + if "probs" not in exp_config["sampling"]: n = len(exp_config["samplers"]) freqs = [100 // n] * n freqs[0] += 100 % n sampling_percent = {k: f for k, f in zip(exp_config["samplers"], freqs)} - exp_config["sampling"] = {"probs": sampling_percent} + exp_config["sampling"]["probs"] = sampling_percent + + if "samplers_per_user" not in exp_config["sampling"]: + exp_config["sampling"]["samplers_per_user"] = 0 + + if exp_config["sampling"]["samplers_per_user"] not in {0, 1}: + s = exp_config["sampling"]["samplers_per_user"] + raise NotImplementedError( + "Only samplers_per_user in {0, 1} is implemented, not " + f"samplers_per_user={s}" + ) if set(exp_config["sampling"]["probs"]) != set(exp_config["samplers"]): sf = set(exp_config["sampling"]["probs"]) diff --git a/salmon/frontend/public.py b/salmon/frontend/public.py index 903ce152..52b6952b 100644 --- a/salmon/frontend/public.py +++ b/salmon/frontend/public.py @@ -116,18 +116,20 @@ async def get_query_page(request: Request, puid: str=""): "debrief": exp_config["debrief"], "skip_button": exp_config["skip_button"], "css": exp_config["css"], + "samplers_per_user": exp_config["sampling"]["samplers_per_user"], } items.update(request=request) return templates.TemplateResponse("query_page.html", items) @app.get("/query", tags=["public"]) -async def get_query() -> Dict[str, Union[int, str, float]]: - idents = rj.jsonget("samplers") - probs = rj.jsonget("sampling_probs") +async def get_query(ident="") -> Dict[str, Union[int, str, float]]: + if ident == "": + idents = rj.jsonget("samplers") + probs = rj.jsonget("sampling_probs") - idx = np.random.choice(len(idents), p=probs) - ident = idents[idx] + idx = np.random.choice(len(idents), p=probs) + ident = idents[idx] r = httpx.get(f"http://localhost:8400/query-{ident}") if r.status_code == 200: diff --git a/templates/query_page.html b/templates/query_page.html index e85b2451..b42b04f2 100644 --- a/templates/query_page.html +++ b/templates/query_page.html @@ -151,6 +151,7 @@ var max_queries = {{ max_queries }}; var num_queries = 0; // queries submitted var show_queries = getTime(); +var samplers_per_user = {{ samplers_per_user }}; async function send(winner, id) { setTimeout(function() { @@ -210,9 +211,20 @@ } var prev_queries = new FixedLengthArray(5); +var ident = ""; function getquery(){ - $.get("/query", function(data){ + console.log(ident); + if ((ident == "") | (samplers_per_user == 0)){ + var endpoint = "/query" + } else if (samplers_per_user == 1) { + var endpoint = "/query?ident=" + ident; + } else { + var msg ="samplers_per_user=" + samplers_per_user + " is not implemented"; + console.log(msg); + alert(msg); + } + $.get(endpoint, function(data){ head = data["head"]; right = data["right"]; left = data["left"]; diff --git a/tests/test_active.py b/tests/test_active.py index df92e6f5..bb405c14 100644 --- a/tests/test_active.py +++ b/tests/test_active.py @@ -74,6 +74,33 @@ def test_active_basics(server, logs): assert set(algs) == {"TSTE", "ARR", "CKL", "tste2", "GNMDS"} +def test_samplers_per_user(server, logs): + server.authorize() + exp = Path(__file__).parent / "data" / "active.yaml" + print("init'ing exp") + exp2 = yaml.safe_load(exp.read_bytes()) + exp2["sampling"] = {"samplers_per_user": 1} + server.post("/init_exp", data={"exp": str(exp2)}) + print("done") + + with open(exp, "r") as f: + config = yaml.load(f, Loader=yaml.SafeLoader) + samplers = list(config["samplers"].keys()) + + ident = random.choice(samplers) + with logs: + for k in range(len(samplers) * 2): + q = server.get(f"/query?ident={ident}").json() + ans = {"winner": random.choice([q["left"], q["right"]]), "puid": "foo", **q} + server.post("/answer", json=ans) + + r = server.get("/responses") + d = r.json() + df = pd.DataFrame(d) + algs = df.alg_ident.unique() + assert len(set(algs)) == 1 + + def test_round_robin(server, logs): server.authorize() exp = Path(__file__).parent / "data" / "round-robin.yaml" From d9a94b38f3ea0be28257d3beaf520bab66de9131 Mon Sep 17 00:00:00 2001 From: Scott Date: Mon, 19 Jul 2021 15:13:54 -0500 Subject: [PATCH 05/51] redirect --- .github/workflows/docs.yml | 8 ++++---- docs/source/_static/alieneggs.html | 20 ++++++++++---------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index d17713cc..2b5bad39 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -1,9 +1,9 @@ name: Documentation build -# on: push -on: - release: - types: [published] +on: push +# on: + # release: + # types: [published] # Only run when release published (not created or edited, etc) # https://docs.github.com/en/actions/reference/events-that-trigger-workflows#release diff --git a/docs/source/_static/alieneggs.html b/docs/source/_static/alieneggs.html index 7f97ecbf..671f2ca0 100644 --- a/docs/source/_static/alieneggs.html +++ b/docs/source/_static/alieneggs.html @@ -1,10 +1,10 @@ - - - - -

Please enable Javascript for to be randomly redirected.

- - + + + + + +
+Redirecting to http://34.223.104.82:8421/... +
From e38472a884e0b3da448d82e51a99dfe33702f948 Mon Sep 17 00:00:00 2001 From: Scott Date: Thu, 22 Jul 2021 14:06:42 -0500 Subject: [PATCH 06/51] another redirect --- docs/source/_static/alieneggs.html | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/docs/source/_static/alieneggs.html b/docs/source/_static/alieneggs.html index 671f2ca0..56e2a185 100644 --- a/docs/source/_static/alieneggs.html +++ b/docs/source/_static/alieneggs.html @@ -1,10 +1,24 @@ + - + + - +
+

Please enable Javascript for a (random) redirection

+
-
-Redirecting to http://34.223.104.82:8421/... -
+ + + From 72416e46dae643de72231f2899c5735944243317 Mon Sep 17 00:00:00 2001 From: Scott Date: Fri, 23 Jul 2021 13:46:16 -0500 Subject: [PATCH 07/51] redirect --- docs/source/_static/alieneggs.html | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/_static/alieneggs.html b/docs/source/_static/alieneggs.html index 56e2a185..934493ea 100644 --- a/docs/source/_static/alieneggs.html +++ b/docs/source/_static/alieneggs.html @@ -8,7 +8,7 @@ From dce1226a55a7334f691bf3298caff8ba31e54543 Mon Sep 17 00:00:00 2001 From: Scott Date: Wed, 28 Jul 2021 13:52:49 -0500 Subject: [PATCH 14/51] redirect --- docs/source/_static/alieneggs.html | 33 +++++++++++++++++++++++------- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/docs/source/_static/alieneggs.html b/docs/source/_static/alieneggs.html index 934493ea..9e40316b 100644 --- a/docs/source/_static/alieneggs.html +++ b/docs/source/_static/alieneggs.html @@ -8,16 +8,35 @@ From 472766723bd7229cd2748c13facf669d36440460 Mon Sep 17 00:00:00 2001 From: Scott Date: Wed, 28 Jul 2021 14:23:45 -0500 Subject: [PATCH 15/51] slightly edits probs --- docs/source/_static/alieneggs.html | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/_static/alieneggs.html b/docs/source/_static/alieneggs.html index 9e40316b..c9c90c57 100644 --- a/docs/source/_static/alieneggs.html +++ b/docs/source/_static/alieneggs.html @@ -8,10 +8,10 @@ From dfabb2e7e5a038602df54b8adce2ad84e3c5f8fe Mon Sep 17 00:00:00 2001 From: Scott Date: Wed, 28 Jul 2021 14:50:59 -0500 Subject: [PATCH 17/51] woah, *properly* redirect --- docs/source/_static/alieneggs.html | 4 ---- 1 file changed, 4 deletions(-) diff --git a/docs/source/_static/alieneggs.html b/docs/source/_static/alieneggs.html index 4cf2da44..d7d2b5ae 100644 --- a/docs/source/_static/alieneggs.html +++ b/docs/source/_static/alieneggs.html @@ -17,7 +17,6 @@ var urls = ["http://34.222.189.214:8421/", "http://35.84.133.58:8421/", "http://35.80.6.172:8421/", "http://35.83.252.68:8421/", "http://44.233.116.46:8421/"]; out = []; - for (let i=0; i<1000; i++){ var r = Math.random(); if (r < prob_m1){ @@ -32,9 +31,6 @@ url = urls[4]; } console.log(url); - out.push(url); - } - console.log(out); //1 0.121 // 0.105 //2 0.064 // 0.052 From df674c2570e911156792638e916b5db0c4723e86 Mon Sep 17 00:00:00 2001 From: Scott Date: Wed, 28 Jul 2021 14:57:32 -0500 Subject: [PATCH 18/51] *actually* properly redirect --- docs/source/_static/alieneggs.html | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/_static/alieneggs.html b/docs/source/_static/alieneggs.html index d7d2b5ae..1993899f 100644 --- a/docs/source/_static/alieneggs.html +++ b/docs/source/_static/alieneggs.html @@ -31,6 +31,8 @@ url = urls[4]; } console.log(url); + document.getElementById("msg").innerHTML = "

Redirecting to " + url + "

"; + window.open(url, '_self'); //1 0.121 // 0.105 //2 0.064 // 0.052 @@ -45,8 +47,6 @@ //http://35.84.133.58:8421/ 0.069 - //document.getElementById("msg").innerHTML = "

Redirecting to " + url + "

"; - //window.open(url, '_self'); From a97c4885aeaf78a08de03c38289cdd7d0ce8e6fc Mon Sep 17 00:00:00 2001 From: Scott Date: Wed, 28 Jul 2021 15:26:16 -0500 Subject: [PATCH 19/51] typo --- docs/source/_static/alieneggs.html | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/docs/source/_static/alieneggs.html b/docs/source/_static/alieneggs.html index 1993899f..29663398 100644 --- a/docs/source/_static/alieneggs.html +++ b/docs/source/_static/alieneggs.html @@ -16,7 +16,6 @@ var urls = ["http://34.222.189.214:8421/", "http://35.84.133.58:8421/", "http://35.80.6.172:8421/", "http://35.83.252.68:8421/", "http://44.233.116.46:8421/"]; - out = []; var r = Math.random(); if (r < prob_m1){ @@ -30,7 +29,6 @@ } else { url = urls[4]; } - console.log(url); document.getElementById("msg").innerHTML = "

Redirecting to " + url + "

"; window.open(url, '_self'); @@ -46,6 +44,11 @@ //http://34.222.189.214:8421/ 0.094 //http://35.84.133.58:8421/ 0.069 +//http://35.80.6.172:8421/ 17 +//http://44.233.116.46:8421/ 17 +//http://35.83.252.68:8421/ 5 +//http://34.222.189.214:8421/ 5 +//http://35.84.133.58:8421/ 2 From d7a41480e4373bd6e1f227c5cc52320a79c0682c Mon Sep 17 00:00:00 2001 From: Scott Date: Fri, 30 Jul 2021 12:13:30 -0500 Subject: [PATCH 20/51] ENH: allow initializing embedding --- salmon/triplets/offline.py | 3 ++- salmon/triplets/samplers/_adaptive_runners.py | 6 +++++- salmon/triplets/samplers/adaptive/_embed.py | 12 ++++++++++-- tests/test_offline.py | 13 +++++++++++++ 4 files changed, 30 insertions(+), 4 deletions(-) diff --git a/salmon/triplets/offline.py b/salmon/triplets/offline.py index 99f17030..c274ac85 100644 --- a/salmon/triplets/offline.py +++ b/salmon/triplets/offline.py @@ -109,7 +109,7 @@ def embedding_(self): """ return self.opt_.embedding_ - def initialize(self, X_train): + def initialize(self, X_train, embedding=None): """ Initialize this optimizer. @@ -134,6 +134,7 @@ def initialize(self, X_train): opt = OGD(**kwargs) else: opt = self.opt + opt.initialize(embedding=embedding) opt.push(X_train) self._meta: Dict[str, Number] = {"pf_calls": 0} diff --git a/salmon/triplets/samplers/_adaptive_runners.py b/salmon/triplets/samplers/_adaptive_runners.py index 4866884b..3227fb48 100644 --- a/salmon/triplets/samplers/_adaptive_runners.py +++ b/salmon/triplets/samplers/_adaptive_runners.py @@ -409,6 +409,7 @@ class SARR(ARR): A adaptive round robin scheme that runs a synchronous search, a modification of :class:`~salmon.triplets.samplers.ARR`. """ + def __init__(self, *args, n_search=400, **kwargs): """ Parameters @@ -432,7 +433,10 @@ def get_query(self): head = int(np.random.choice(self.n)) _choices = list(set(range(self.n)) - {head}) choices = np.array(_choices) - bottoms = [np.random.choice(choices, size=2, replace=False) for _ in range(self.n_search)] + bottoms = [ + np.random.choice(choices, size=2, replace=False) + for _ in range(self.n_search) + ] _queries = [[head, l, r] for l, r in bottoms] queries, scores = self.search.score(queries=_queries) diff --git a/salmon/triplets/samplers/adaptive/_embed.py b/salmon/triplets/samplers/adaptive/_embed.py index c28a7351..4587470d 100644 --- a/salmon/triplets/samplers/adaptive/_embed.py +++ b/salmon/triplets/samplers/adaptive/_embed.py @@ -1,6 +1,6 @@ import itertools from copy import deepcopy -from typing import List, Tuple, Union +from typing import List, Tuple, Union, Optional import numpy as np import torch @@ -77,7 +77,7 @@ def __init__( self.initial_batch_size = initial_batch_size self.kwargs = kwargs - def initialize(self): + def initialize(self, embedding: Optional[np.ndarray] = None): """ Initialize this optimization algorithm. """ @@ -99,6 +99,14 @@ def initialize(self): train_split=None, dataset=NumpyDataset, ).initialize() + if embedding is not None: + if not isinstance(embedding, np.ndarray): + raise ValueError( + f"Specify embedding as a NumPy array, not a {type(embedding)}" + ) + with torch.no_grad(): + em = torch.from_numpy(embedding.astype("float32")) + self.net_.module_.embedding.data = em return self # def converged(self): diff --git a/tests/test_offline.py b/tests/test_offline.py index 91039577..c023c754 100644 --- a/tests/test_offline.py +++ b/tests/test_offline.py @@ -85,3 +85,16 @@ def test_offline_embedding_random_state(): n=n, d=d, max_epochs=max_epochs, random_state=random_state ).initialize(X_train) assert np.allclose(m1.embedding_, m2.embedding_) + + +def test_offline_init(): + n, d = 20, 2 + + X = np.random.choice(n, size=(100, 3)) + em = np.random.uniform(size=(n, d)) + est = OfflineEmbedding(n=n, d=d) + est.initialize(X, embedding=em) + + assert np.allclose(est.embedding_, em) + est.partial_fit(X) + assert not np.allclose(est.embedding_, em), "Embedding didn't change" From 12ed61e0fcbb2e9db9bc00fc76aff14353157ed6 Mon Sep 17 00:00:00 2001 From: Scott Date: Fri, 30 Jul 2021 13:35:57 -0500 Subject: [PATCH 21/51] BUG: allow some samplers not to have embeddings --- salmon/frontend/private.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/salmon/frontend/private.py b/salmon/frontend/private.py index b663ae00..2dbee55a 100644 --- a/salmon/frontend/private.py +++ b/salmon/frontend/private.py @@ -708,7 +708,16 @@ async def get_embeddings( exp_config = deepcopy(exp_config) targets = exp_config.pop("targets") alg_idents = list(exp_config.pop("samplers").keys()) - embeddings = {alg: await get_model(alg) for alg in alg_idents} + embeddings = {} + for alg in alg_idents: + try: + embeddings[alg] = await get_model(alg) + except: + pass + if len(embeddings) == 0: + raise ServerException( + f"No model has been created for any sampler in {alg_idents}" + ) dfs = { alg: _fmt_embedding(model["embedding"], targets, alg=alg) for alg, model in embeddings.items() From bfa29c4655930c8cd7c42fcc6d146e5713f91ad8 Mon Sep 17 00:00:00 2001 From: Scott Date: Fri, 30 Jul 2021 13:39:09 -0500 Subject: [PATCH 22/51] DOC: show usage of initialize --- docs/source/offline.rst | 6 +++++- salmon/triplets/offline.py | 19 +++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/docs/source/offline.rst b/docs/source/offline.rst index 2e514762..51fbc89b 100644 --- a/docs/source/offline.rst +++ b/docs/source/offline.rst @@ -45,14 +45,18 @@ This code will generate an embedding: from salmon.triplets.offline import OfflineEmbedding - df = pd.read_csv("responses.csv") + df = pd.read_csv("responses.csv") # from dashboard X = df[["head", "winner", "loser"]].to_numpy() + em = pd.read_csv("embedding.csv") # from dashboard + n = int(X.max() + 1) # number of targets d = 2 # embed into 2 dimensions X_train, X_test = train_test_split(X, random_state=42, test_size=0.2) model = OfflineEmbedding(n=n, d=d) + model.initialize(X_train, embedding=em.to_numpy()) + model.fit(X_train, X_test) model.embedding_ # embedding diff --git a/salmon/triplets/offline.py b/salmon/triplets/offline.py index c274ac85..4b8f9de4 100644 --- a/salmon/triplets/offline.py +++ b/salmon/triplets/offline.py @@ -118,6 +118,25 @@ def initialize(self, X_train, embedding=None): X_train : np.ndarray Responses organized to be [head, winner, loser]. + embedding : nd.ndarray, optional + If specified, initialize the embedding with the given values. + + .. note:: + + This is particularly useful when ``embedding`` is the + online embedding from the CSV: + + .. code-block:: python + + import pandas as pd + em = pd.read_csv("embedding.csv") # from dashboard + df = pd.read_csv("responses.csv") # from dashboard + X = df[["head", "winner", "loser"]] + + from salmon.triplets.offline import OfflineEmbedding + est = OfflineEmbedding(...) + est.initialize(X, embedding=em) + """ if self.opt is None: assert self.n is not None and self.d is not None, "Specify n and d" From 3970c20e047efd6505bc9e7b405504cdb55f408b Mon Sep 17 00:00:00 2001 From: Scott Date: Mon, 2 Aug 2021 17:28:05 -0500 Subject: [PATCH 23/51] add embedding init --- salmon/triplets/offline.py | 52 ++++++++++++++++++++------------------ 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/salmon/triplets/offline.py b/salmon/triplets/offline.py index 4b8f9de4..1963a583 100644 --- a/salmon/triplets/offline.py +++ b/salmon/triplets/offline.py @@ -121,22 +121,6 @@ def initialize(self, X_train, embedding=None): embedding : nd.ndarray, optional If specified, initialize the embedding with the given values. - .. note:: - - This is particularly useful when ``embedding`` is the - online embedding from the CSV: - - .. code-block:: python - - import pandas as pd - em = pd.read_csv("embedding.csv") # from dashboard - df = pd.read_csv("responses.csv") # from dashboard - X = df[["head", "winner", "loser"]] - - from salmon.triplets.offline import OfflineEmbedding - est = OfflineEmbedding(...) - est.initialize(X, embedding=em) - """ if self.opt is None: assert self.n is not None and self.d is not None, "Specify n and d" @@ -179,7 +163,7 @@ def partial_fit(self, X_train): self._partial_fit(X_train) return self - def fit(self, X_train, X_test): + def fit(self, X_train, X_test, embedding=None): """ Fit the embedding with train and validation data. @@ -196,21 +180,38 @@ def fit(self, X_train, X_test): The responses with shape ``(n_questions, 3)``. Each question is organized as ``[head, winner, loser]``. + + embedding : np.ndarray, optional + The embedding to initialize with. + + .. note:: + + This is particularly useful when ``embedding`` is the + online embedding from the CSV: + + .. code-block:: python + + import pandas as pd + em = pd.read_csv("embedding.csv") # from dashboard + df = pd.read_csv("responses.csv") # from dashboard + X = df[["head", "winner", "loser"]] + + from salmon.triplets.offline import OfflineEmbedding + est = OfflineEmbedding(...) + est.initialize(X, embedding=em) + """ - self.initialize(X_train) + self.initialize(X_train, embedding=embedding) self._meta["pf_calls"] = 0 _start = time() for k in itertools.count(): self._partial_fit(X_train) if self.verbose and k == 0: print(self.opt_.optimizer, self.opt_.get_params()) - if self.opt_.meta_["num_grad_comps"] >= self.max_epochs * len(X_train): - break - - if self.verbose and ( - k % self.verbose == 0 or abs(self.max_epochs - k) <= 3 - ): + if (self.verbose and k % self.verbose == 0) or abs( + self.max_epochs - k + ) <= 3: datum = deepcopy(self._meta) datum.update(self.opt_.meta_) test_score, loss_test = self._score(X_test) @@ -227,7 +228,8 @@ def fit(self, X_train, X_test): self._history_.append(datum) if self.verbose and k % self.verbose == 0: print(show) - + if self.opt_.meta_["num_grad_comps"] >= self.max_epochs * len(X_train): + break test_score, loss_test = self._score(X_test) self._history_[-1]["score_test"] = test_score self._history_[-1]["loss_test"] = loss_test From 9c083e501b19e4fd97950d8eaef456fc627bacf2 Mon Sep 17 00:00:00 2001 From: Scott Date: Mon, 2 Aug 2021 18:46:39 -0500 Subject: [PATCH 24/51] bump --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index d0b06fc7..7f5fbc2a 100644 --- a/README.md +++ b/README.md @@ -5,4 +5,5 @@ Documentation status badge + See the documentation for more detail: https://docs.stsievert.com/salmon/ From 7556c1af377decaf3193e5450aae6a3bf32a0371 Mon Sep 17 00:00:00 2001 From: Scott Date: Tue, 17 Aug 2021 15:50:13 -0500 Subject: [PATCH 25/51] Update continuumio/miniconda3 docker version --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 06bf75b7..d9321995 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM continuumio/miniconda3:4.9.2 +FROM continuumio/miniconda3:4.10.3 RUN apt-get update RUN apt-get install -y gcc cmake g++ From c120a16d155f5b647dcbc40f8881f1240aeede93 Mon Sep 17 00:00:00 2001 From: Scott Date: Tue, 17 Aug 2021 18:08:22 -0500 Subject: [PATCH 26/51] alpha preload --- salmon/frontend/public.py | 7 ++++++- templates/query_page.html | 34 +++++++++++++++++++++------------- 2 files changed, 27 insertions(+), 14 deletions(-) diff --git a/salmon/frontend/public.py b/salmon/frontend/public.py index 52b6952b..764228ed 100644 --- a/salmon/frontend/public.py +++ b/salmon/frontend/public.py @@ -19,7 +19,7 @@ from ..triplets import manager from ..utils import get_logger -from .utils import ServerException, sha256 +from .utils import ServerException, sha256, image_url logger = get_logger(__name__) @@ -108,6 +108,10 @@ async def get_query_page(request: Request, puid: str=""): if puid == "": uid = "salmon-{}".format(np.random.randint(2 ** 32 - 1)) puid = sha256(uid)[:16] + try: + urls = [image_url(t) for t in exp_config["targets"]] + except: + urls = [] items = { "puid": puid, "instructions": exp_config["instructions"], @@ -117,6 +121,7 @@ async def get_query_page(request: Request, puid: str=""): "skip_button": exp_config["skip_button"], "css": exp_config["css"], "samplers_per_user": exp_config["sampling"]["samplers_per_user"], + "urls": urls, } items.update(request=request) return templates.TemplateResponse("query_page.html", items) diff --git a/templates/query_page.html b/templates/query_page.html index 0d65c537..984f5a4a 100644 --- a/templates/query_page.html +++ b/templates/query_page.html @@ -13,6 +13,10 @@ Salmon + +{% for url in urls %} + +{% endfor %}