# Hybrid search example: semantic + keyword retrieval with cross-encoder re-ranking.
# good resources
# https://qdrant.tech/articles/hybrid-search/
# https://www.sbert.net/examples/applications/semantic-search/README.html
import asyncio
import itertools
from pgvector.psycopg import register_vector_async
import psycopg
from sentence_transformers import CrossEncoder, SentenceTransformer
# Tiny demo corpus indexed by both search paths (semantic and keyword).
sentences = [
    'The dog is barking',
    'The cat is purring',
    'The bear is growling'
]
# Query issued against both search paths in main().
query = 'growling bear'
async def create_schema(conn):
    """Set up the demo schema: vector extension, documents table, and a GIN
    full-text index for the keyword search path."""
    # The extension must exist before the vector type can be registered.
    await conn.execute('CREATE EXTENSION IF NOT EXISTS vector')
    await register_vector_async(conn)
    ddl_statements = (
        'DROP TABLE IF EXISTS documents',
        'CREATE TABLE documents (id bigserial PRIMARY KEY, content text, embedding vector(384))',
        "CREATE INDEX ON documents USING GIN (to_tsvector('english', content))",
    )
    for statement in ddl_statements:
        await conn.execute(statement)
async def insert_data(conn):
    """Embed the demo corpus and bulk-insert (content, embedding) rows in a
    single multi-values INSERT statement."""
    model = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
    embeddings = model.encode(sentences)
    # One "(%s, %s)" placeholder pair per row, all in one statement.
    placeholders = ', '.join('(%s, %s)' for _ in embeddings)
    sql = 'INSERT INTO documents (content, embedding) VALUES ' + placeholders
    # Flatten to [content0, emb0, content1, emb1, ...] to match the placeholders.
    params = []
    for content, embedding in zip(sentences, embeddings):
        params.extend([content, embedding])
    await conn.execute(sql, params)
async def semantic_search(conn, query):
    """Return up to 5 (id, content) rows nearest to the query embedding,
    ordered by cosine distance (the <=> operator)."""
    encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
    query_embedding = encoder.encode(query)
    sql = 'SELECT id, content FROM documents ORDER BY embedding <=> %s LIMIT 5'
    async with conn.cursor() as cursor:
        await cursor.execute(sql, (query_embedding,))
        return await cursor.fetchall()
async def keyword_search(conn, query):
    """Return up to 5 (id, content) rows matching the query via Postgres
    full-text search, ranked by ts_rank_cd."""
    sql = "SELECT id, content FROM documents, plainto_tsquery('english', %s) query WHERE to_tsvector('english', content) @@ query ORDER BY ts_rank_cd(to_tsvector('english', content), query) DESC LIMIT 5"
    async with conn.cursor() as cursor:
        await cursor.execute(sql, (query,))
        return await cursor.fetchall()
def rerank(query, results):
    """Merge the per-search result lists, drop duplicate rows, and return the
    rows ordered by descending cross-encoder relevance score."""
    # Flatten both result lists and deduplicate; rows are (id, content) tuples.
    unique_rows = set(itertools.chain.from_iterable(results))
    encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    # Score each (query, content) pair; the set is not modified between the
    # two iterations, so its order is consistent for the zip below.
    pairs = [(query, row[1]) for row in unique_rows]
    scores = encoder.predict(pairs)
    ranked = sorted(zip(scores, unique_rows), reverse=True)
    return [row for _, row in ranked]
async def main():
    """Build the demo schema, insert the corpus, run both search paths
    concurrently for the module-level query, and print the re-ranked rows."""
    # Fix: the original never closed the connection; `async with` guarantees
    # it is closed even if a query raises.
    async with await psycopg.AsyncConnection.connect(dbname='pgvector_example', autocommit=True) as conn:
        await create_schema(conn)
        await insert_data(conn)
        # perform queries in parallel
        results = await asyncio.gather(semantic_search(conn, query), keyword_search(conn, query))
        results = rerank(query, results)
        print(results)
asyncio.run(main())