dpr_technique.py
from transformers import (
    DPRQuestionEncoder,
    DPRContextEncoder,
    DPRQuestionEncoderTokenizer,
    DPRContextEncoderTokenizer,
)
import torch
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Load pre-trained DPR models and tokenizers
question_encoder = DPRQuestionEncoder.from_pretrained(
    "facebook/dpr-question_encoder-single-nq-base"
)
context_encoder = DPRContextEncoder.from_pretrained(
    "facebook/dpr-ctx_encoder-single-nq-base"
)
question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
    "facebook/dpr-question_encoder-single-nq-base"
)
context_tokenizer = DPRContextEncoderTokenizer.from_pretrained(
    "facebook/dpr-ctx_encoder-single-nq-base"
)

# Encode a query into a single dense vector.
# pooler_output has shape (1, hidden_size): one embedding per input sequence.
query = "capital of Africa?"
question_inputs = question_tokenizer(query, return_tensors="pt")
question_embedding = question_encoder(**question_inputs).pooler_output

# Candidate passages to rank against the query.
passages = [
    "Paris is the capital of France.",
    "Berlin is the capital of Germany.",
    "Madrid is the capital of Spain.",
    "Rome is the capital of Italy.",
    "Maputo is the capital of Mozambique.",
    "To be or not to be, that is the question.",
    "The quick brown fox jumps over the lazy dog.",
    "Grace Hopper was an American computer scientist and United States Navy rear admiral who was a pioneer of computer programming, one of the first programmers of the Harvard Mark I computer, and the inventor of the first compiler for a computer programming language.",
]

# Encode each passage and stack into a (num_passages, hidden_size) tensor.
context_embeddings = []
for passage in passages:
    context_inputs = context_tokenizer(passage, return_tensors="pt")
    context_embedding = context_encoder(**context_inputs).pooler_output
    context_embeddings.append(context_embedding)
context_embeddings = torch.cat(context_embeddings, dim=0)
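
# Alternative (a sketch, not in the original): encode all passages in one
# padded batch instead of looping one at a time. truncation=True guards
# against passages longer than the encoder's maximum input length; the
# attention mask keeps padding from affecting the embeddings.
batch_inputs = context_tokenizer(
    passages, padding=True, truncation=True, return_tensors="pt"
)
batch_embeddings = context_encoder(**batch_inputs).pooler_output  # same vectors, batched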

# Compute query-passage similarities. DPR was trained with dot-product
# (inner-product) scores; cosine similarity is used here and typically
# preserves the same ranking.
similarities = cosine_similarity(
    question_embedding.detach().numpy(), context_embeddings.detach().numpy()
)
print("Similarities:", similarities)

# Get the most relevant passage (argmax over the flattened (1, N) score array).
most_relevant_idx = np.argmax(similarities)
print("Most relevant passage:", passages[most_relevant_idx])