-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathText2Image_search.py
35 lines (27 loc) · 1.41 KB
/
Text2Image_search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import os
from haystack import Document
from haystack import Pipeline
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes.retriever.multimodal import MultiModalRetriever
class MultimodalSearch:
    """Text-to-image semantic search over a local image folder.

    Indexes every file under ``./Data`` as an image document in an in-memory
    store, embeds the images with a CLIP model, and exposes a query pipeline
    that retrieves the images best matching a free-text query.
    """

    def __init__(self):
        # CLIP ViT-B/32 produces 512-dimensional embeddings, so the store's
        # embedding_dim must match.
        self.document_store = InMemoryDocumentStore(embedding_dim=512)
        doc_dir = "Data"
        # BUG FIX: the original f-string contained a hard-coded placeholder
        # instead of `{filename}`, so every Document pointed at the same
        # nonexistent path and the loop variable was unused. Interpolate the
        # actual filename from the directory listing.
        images = [
            Document(content=f"./{doc_dir}/{filename}", content_type="image")
            for filename in os.listdir(f"./{doc_dir}")
        ]
        self.document_store.write_documents(images)

        # CLIP maps text queries and image documents into the same embedding
        # space, enabling text -> image retrieval with one shared model.
        self.retriever_text_to_image = MultiModalRetriever(
            document_store=self.document_store,
            query_embedding_model="sentence-transformers/clip-ViT-B-32",
            query_type="text",
            document_embedding_models={"image": "sentence-transformers/clip-ViT-B-32"},
        )

        # Turn images into embeddings and store them in the DocumentStore.
        self.document_store.update_embeddings(retriever=self.retriever_text_to_image)

        self.pipeline = Pipeline()
        self.pipeline.add_node(component=self.retriever_text_to_image, name="retriever_text_to_image", inputs=["Query"])

    def search(self, query, top_k=3):
        """Return the `top_k` image Documents best matching `query` text.

        Results are sorted by similarity score, highest first.
        """
        results = self.pipeline.run(query=query, params={"retriever_text_to_image": {"top_k": top_k}})
        return sorted(results["documents"], key=lambda d: d.score, reverse=True)