-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* match torchscorer tqdm * initial commit of the vector based prf * cleanup * cleanup * reworked prf as transformers, added a simple test * using ``pta.transform.by_query`` * pull forward all query columns * bibtex block in documentation * documentation * whoops * more tests --------- Co-authored-by: Sean MacAvaney <sean.macavaney@gmail.com>
- Loading branch information
1 parent
6909587
commit 6fe2564
Showing
7 changed files
with
302 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
import numpy as np | ||
import pandas as pd | ||
import pyterrier as pt | ||
import pyterrier_alpha as pta | ||
|
||
|
||
class VectorPrf(pt.Transformer): | ||
""" | ||
Performs a Rocchio-esque PRF by linearly combining the query_vec column with | ||
the doc_vec column of the top k documents. | ||
Arguments: | ||
- alpha: weight of original query_vec | ||
- beta: weight of doc_vec | ||
- k: number of pseudo-relevant feedback documents | ||
Expected Input Columns: ``['qid', 'query_vec', 'docno', 'doc_vec']`` | ||
Output Columns: ``['qid', 'query_vec']`` (Any other query columns from the input are also pulled included in the output.) | ||
Example:: | ||
prf_pipe = model >> index >> index.vec_loader() >> pyterier_dr.vector_prf() >> index | ||
.. code-block:: bibtex | ||
:caption: Citation | ||
@article{DBLP:journals/tois/0009MZKZ23, | ||
author = {Hang Li and | ||
Ahmed Mourad and | ||
Shengyao Zhuang and | ||
Bevan Koopman and | ||
Guido Zuccon}, | ||
title = {Pseudo Relevance Feedback with Deep Language Models and Dense Retrievers: | ||
Successes and Pitfalls}, | ||
journal = {{ACM} Trans. Inf. Syst.}, | ||
volume = {41}, | ||
number = {3}, | ||
pages = {62:1--62:40}, | ||
year = {2023}, | ||
url = {https://doi.org/10.1145/3570724}, | ||
doi = {10.1145/3570724}, | ||
timestamp = {Fri, 21 Jul 2023 22:26:51 +0200}, | ||
biburl = {https://dblp.org/rec/journals/tois/0009MZKZ23.bib}, | ||
bibsource = {dblp computer science bibliography, https://dblp.org} | ||
} | ||
""" | ||
def __init__(self, | ||
*, | ||
alpha: float = 1, | ||
beta: float = 0.2, | ||
k: int = 3 | ||
): | ||
self.alpha = alpha | ||
self.beta = beta | ||
self.k = k | ||
|
||
@pta.transform.by_query(add_ranks=False) | ||
def transform(self, inp: pd.DataFrame) -> pd.DataFrame: | ||
"""Performs Vector PRF on the input dataframe.""" | ||
pta.validate.result_frame(inp, extra_columns=['query_vec', 'doc_vec']) | ||
|
||
query_cols = [col for col in inp.columns if col.startswith('q') and col != 'query_vec'] | ||
|
||
# get the docvectors for the top k docs | ||
doc_vecs = np.stack([ row.doc_vec for row in inp.head(self.k).itertuples() ]) | ||
# combine their average and add to the query | ||
query_vec = self.alpha * inp['query_vec'].iloc[0] + self.beta * np.mean(doc_vecs, axis=0) | ||
# generate new query dataframe with the existing query columns and the new query_vec | ||
return pd.DataFrame([[inp[c].iloc[0] for c in query_cols] + [query_vec]], columns=query_cols + ['query_vec']) | ||
|
||
def __repr__(self): | ||
return f"VectorPrf(alpha={self.alpha}, beta={self.beta}, k={self.k})" | ||
|
||
|
||
class AveragePrf(pt.Transformer): | ||
""" | ||
Performs Average PRF (as described by Li et al.) by averaging the query_vec column with | ||
the doc_vec column of the top k documents. | ||
Arguments: | ||
- k: number of pseudo-relevant feedback documents | ||
Expected Input Columns: ``['qid', 'query_vec', 'docno', 'doc_vec']`` | ||
Output Columns: ``['qid', 'query_vec']`` (Any other query columns from the input are also pulled included in the output.) | ||
Example:: | ||
prf_pipe = model >> index >> index.vec_loader() >> pyterier_dr.average_prf() >> index | ||
.. code-block:: bibtex | ||
:caption: Citation | ||
@article{DBLP:journals/tois/0009MZKZ23, | ||
author = {Hang Li and | ||
Ahmed Mourad and | ||
Shengyao Zhuang and | ||
Bevan Koopman and | ||
Guido Zuccon}, | ||
title = {Pseudo Relevance Feedback with Deep Language Models and Dense Retrievers: | ||
Successes and Pitfalls}, | ||
journal = {{ACM} Trans. Inf. Syst.}, | ||
volume = {41}, | ||
number = {3}, | ||
pages = {62:1--62:40}, | ||
year = {2023}, | ||
url = {https://doi.org/10.1145/3570724}, | ||
doi = {10.1145/3570724}, | ||
timestamp = {Fri, 21 Jul 2023 22:26:51 +0200}, | ||
biburl = {https://dblp.org/rec/journals/tois/0009MZKZ23.bib}, | ||
bibsource = {dblp computer science bibliography, https://dblp.org} | ||
} | ||
""" | ||
def __init__(self, | ||
*, | ||
k: int = 3 | ||
): | ||
self.k = k | ||
|
||
@pta.transform.by_query(add_ranks=False) | ||
def transform(self, inp: pd.DataFrame) -> pd.DataFrame: | ||
"""Performs Average PRF on the input dataframe.""" | ||
pta.validate.result_frame(inp, extra_columns=['query_vec', 'doc_vec']) | ||
|
||
query_cols = [col for col in inp.columns if col.startswith('q') and col != 'query_vec'] | ||
|
||
# get the docvectors for the top k docs and the query_vec | ||
all_vecs = np.stack([inp['query_vec'].iloc[0]] + [row.doc_vec for row in inp.head(self.k).itertuples()]) | ||
# combine their average and add to the query | ||
query_vec = np.mean(all_vecs, axis=0) | ||
# generate new query dataframe with the existing query columns and the new query_vec | ||
return pd.DataFrame([[inp[c].iloc[0] for c in query_cols] + [query_vec]], columns=query_cols + ['query_vec']) | ||
|
||
def __repr__(self): | ||
return f"AveragePrf(k={self.k})" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
Pseudo Relevance Feedback (PRF) | ||
=============================== | ||
|
||
Dense Pseudo Relevance Feedback (PRF) is a technique to improve the performance of a retrieval system by expanding the | ||
original query vector with the vectors from the top-ranked documents. The idea is that the top-ranked documents. | ||
|
||
PyTerrier-DR provides two dense PRF implementations: :class:`pyterrier_dr.AveragePrf` and :class:`pyterrier_dr.VectorPrf`. | ||
|
||
API Documentation | ||
----------------- | ||
|
||
.. autoclass:: pyterrier_dr.AveragePrf | ||
:members: | ||
|
||
.. autoclass:: pyterrier_dr.VectorPrf | ||
:members: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
import unittest | ||
import numpy as np | ||
import pandas as pd | ||
from pyterrier_dr import AveragePrf, VectorPrf | ||
|
||
|
||
class TestModels(unittest.TestCase): | ||
|
||
def test_average_prf(self): | ||
prf = AveragePrf() | ||
with self.subTest('single row'): | ||
inp = pd.DataFrame([['q1', 'query', np.array([1, 2, 3]), 'd1', np.array([4, 5, 6])]], columns=['qid', 'query', 'query_vec', 'docno', 'doc_vec']) | ||
out = prf(inp) | ||
self.assertEqual(out.columns.tolist(), ['qid', 'query', 'query_vec']) | ||
self.assertEqual(len(out), 1) | ||
self.assertEqual(out['qid'].iloc[0], 'q1') | ||
self.assertEqual(out['query'].iloc[0], 'query') | ||
np.testing.assert_array_equal(out['query_vec'].iloc[0], np.array([2.5, 3.5, 4.5])) | ||
|
||
with self.subTest('multiple rows'): | ||
inp = pd.DataFrame([ | ||
['q1', 'query', np.array([1, 2, 3]), 'd1', np.array([4, 5, 6])], | ||
['q1', 'query', np.array([1, 2, 3]), 'd2', np.array([1, 4, 2])], | ||
['q1', 'query', np.array([1, 2, 3]), 'd3', np.array([8, 7, 1])], | ||
], columns=['qid', 'query', 'query_vec', 'docno', 'doc_vec']) | ||
out = prf(inp) | ||
self.assertEqual(out.columns.tolist(), ['qid', 'query', 'query_vec']) | ||
self.assertEqual(len(out), 1) | ||
np.testing.assert_array_equal(out['query_vec'].iloc[0], np.array([3.5, 4.5, 3.])) | ||
|
||
with self.subTest('multiple rows -- k=3'): | ||
inp = pd.DataFrame([ | ||
['q1', 'query', np.array([1, 2, 3]), 'd1', np.array([4, 5, 6])], | ||
['q1', 'query', np.array([1, 2, 3]), 'd2', np.array([1, 4, 2])], | ||
['q1', 'query', np.array([1, 2, 3]), 'd3', np.array([8, 7, 1])], | ||
['q1', 'query', np.array([1, 2, 3]), 'd4', np.array([100, 100, 100])], | ||
], columns=['qid', 'query', 'query_vec', 'docno', 'doc_vec']) | ||
out = prf(inp) | ||
self.assertEqual(out.columns.tolist(), ['qid', 'query', 'query_vec']) | ||
self.assertEqual(len(out), 1) | ||
np.testing.assert_array_equal(out['query_vec'].iloc[0], np.array([3.5, 4.5, 3.])) | ||
|
||
with self.subTest('multiple rows -- k=1'): | ||
prf.k = 1 | ||
inp = pd.DataFrame([ | ||
['q1', 'query', np.array([1, 2, 3]), 'd1', np.array([4, 5, 6])], | ||
['q1', 'query', np.array([1, 2, 3]), 'd2', np.array([1, 4, 2])], | ||
['q1', 'query', np.array([1, 2, 3]), 'd3', np.array([8, 7, 1])], | ||
['q1', 'query', np.array([1, 2, 3]), 'd4', np.array([100, 100, 100])], | ||
], columns=['qid', 'query', 'query_vec', 'docno', 'doc_vec']) | ||
out = prf(inp) | ||
self.assertEqual(out.columns.tolist(), ['qid', 'query', 'query_vec']) | ||
self.assertEqual(len(out), 1) | ||
np.testing.assert_array_equal(out['query_vec'].iloc[0], np.array([2.5, 3.5, 4.5])) | ||
|
||
with self.subTest('multiple queries'): | ||
prf.k = 3 | ||
inp = pd.DataFrame([ | ||
['q1', 'query', np.array([1, 2, 3]), 'd1', np.array([4, 5, 6])], | ||
['q1', 'query', np.array([1, 2, 3]), 'd1', np.array([1, 4, 2])], | ||
['q1', 'query', np.array([1, 2, 3]), 'd1', np.array([8, 7, 1])], | ||
['q2', 'query2', np.array([4, 6, 1]), 'd1', np.array([9, 4, 2])], | ||
], columns=['qid', 'query', 'query_vec', 'docno', 'doc_vec']) | ||
out = prf(inp) | ||
self.assertEqual(out.columns.tolist(), ['qid', 'query', 'query_vec']) | ||
self.assertEqual(len(out), 2) | ||
self.assertEqual(out['qid'].iloc[0], 'q1') | ||
np.testing.assert_array_equal(out['query_vec'].iloc[0], np.array([3.5, 4.5, 3.])) | ||
self.assertEqual(out['qid'].iloc[1], 'q2') | ||
np.testing.assert_array_equal(out['query_vec'].iloc[1], np.array([6.5, 5., 1.5])) | ||
|
||
def test_vector_prf(self): | ||
prf = VectorPrf(alpha=0.5, beta=0.5) | ||
with self.subTest('single row'): | ||
inp = pd.DataFrame([['q1', 'query', np.array([1, 2, 3]), 'd1', np.array([4, 5, 6])]], columns=['qid', 'query', 'query_vec', 'docno', 'doc_vec']) | ||
out = prf(inp) | ||
self.assertEqual(out.columns.tolist(), ['qid', 'query', 'query_vec']) | ||
self.assertEqual(len(out), 1) | ||
self.assertEqual(out['qid'].iloc[0], 'q1') | ||
self.assertEqual(out['query'].iloc[0], 'query') | ||
np.testing.assert_array_equal(out['query_vec'].iloc[0], np.array([2.5, 3.5, 4.5])) | ||
|
||
with self.subTest('multiple rows'): | ||
inp = pd.DataFrame([ | ||
['q1', 'query', np.array([1, 2, 3]), 'd1', np.array([4, 5, 6])], | ||
['q1', 'query', np.array([1, 2, 3]), 'd2', np.array([1, 4, 2])], | ||
['q1', 'query', np.array([1, 2, 3]), 'd3', np.array([8, 7, 1])], | ||
], columns=['qid', 'query', 'query_vec', 'docno', 'doc_vec']) | ||
out = prf(inp) | ||
self.assertEqual(out.columns.tolist(), ['qid', 'query', 'query_vec']) | ||
self.assertEqual(len(out), 1) | ||
np.testing.assert_almost_equal(out['query_vec'].iloc[0], np.array([2.666667, 3.666667, 3.]), decimal=5) | ||
|
||
with self.subTest('multiple rows -- k=3'): | ||
inp = pd.DataFrame([ | ||
['q1', 'query', np.array([1, 2, 3]), 'd1', np.array([4, 5, 6])], | ||
['q1', 'query', np.array([1, 2, 3]), 'd2', np.array([1, 4, 2])], | ||
['q1', 'query', np.array([1, 2, 3]), 'd3', np.array([8, 7, 1])], | ||
['q1', 'query', np.array([1, 2, 3]), 'd4', np.array([100, 100, 100])], | ||
], columns=['qid', 'query', 'query_vec', 'docno', 'doc_vec']) | ||
out = prf(inp) | ||
self.assertEqual(out.columns.tolist(), ['qid', 'query', 'query_vec']) | ||
self.assertEqual(len(out), 1) | ||
np.testing.assert_almost_equal(out['query_vec'].iloc[0], np.array([2.666667, 3.666667, 3.]), decimal=5) | ||
|
||
with self.subTest('multiple rows -- k=1'): | ||
prf.k = 1 | ||
inp = pd.DataFrame([ | ||
['q1', 'query', np.array([1, 2, 3]), 'd1', np.array([4, 5, 6])], | ||
['q1', 'query', np.array([1, 2, 3]), 'd2', np.array([1, 4, 2])], | ||
['q1', 'query', np.array([1, 2, 3]), 'd3', np.array([8, 7, 1])], | ||
['q1', 'query', np.array([1, 2, 3]), 'd4', np.array([100, 100, 100])], | ||
], columns=['qid', 'query', 'query_vec', 'docno', 'doc_vec']) | ||
out = prf(inp) | ||
self.assertEqual(out.columns.tolist(), ['qid', 'query', 'query_vec']) | ||
self.assertEqual(len(out), 1) | ||
np.testing.assert_array_equal(out['query_vec'].iloc[0], np.array([2.5, 3.5, 4.5])) | ||
|
||
with self.subTest('multiple queries'): | ||
prf.k = 3 | ||
inp = pd.DataFrame([ | ||
['q1', 'query', np.array([1, 2, 3]), 'd1', np.array([4, 5, 6])], | ||
['q1', 'query', np.array([1, 2, 3]), 'd1', np.array([1, 4, 2])], | ||
['q1', 'query', np.array([1, 2, 3]), 'd1', np.array([8, 7, 1])], | ||
['q2', 'query2', np.array([4, 6, 1]), 'd1', np.array([9, 4, 2])], | ||
], columns=['qid', 'query', 'query_vec', 'docno', 'doc_vec']) | ||
out = prf(inp) | ||
self.assertEqual(out.columns.tolist(), ['qid', 'query', 'query_vec']) | ||
self.assertEqual(len(out), 2) | ||
self.assertEqual(out['qid'].iloc[0], 'q1') | ||
np.testing.assert_almost_equal(out['query_vec'].iloc[0], np.array([2.666667, 3.666667, 3.]), decimal=5) | ||
self.assertEqual(out['qid'].iloc[1], 'q2') | ||
np.testing.assert_array_equal(out['query_vec'].iloc[1], np.array([6.5, 5., 1.5])) | ||
|
||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |