Summary multi docs #84

Merged: 8 commits, May 26, 2023
456 changes: 388 additions & 68 deletions ammico/multimodal_search.py

Large diffs are not rendered by default.

53 changes: 50 additions & 3 deletions ammico/summary.py
@@ -7,9 +7,18 @@
class SummaryDetector(AnalysisMethod):
def __init__(self, subdict: dict) -> None:
super().__init__(subdict)
self.summary_device = device("cuda" if cuda.is_available() else "cpu")
self.summary_device = "cuda" if cuda.is_available() else "cpu"

def load_model_base(self):
"""
Load base_coco blip_caption model and preprocessors for visual inputs from lavis.models.

Args:

Returns:
model (torch.nn.Module): model.
vis_processors (dict): preprocessors for visual inputs.
"""
summary_model, summary_vis_processors, _ = load_model_and_preprocess(
name="blip_caption",
model_type="base_coco",
@@ -19,6 +28,15 @@ def load_model_base(self):
return summary_model, summary_vis_processors

def load_model_large(self):
"""
Load large_coco blip_caption model and preprocessors for visual inputs from lavis.models.

Args:

Returns:
model (torch.nn.Module): model.
vis_processors (dict): preprocessors for visual inputs.
"""
summary_model, summary_vis_processors, _ = load_model_and_preprocess(
name="blip_caption",
model_type="large_coco",
@@ -27,7 +45,17 @@ def load_model_large(self):
)
return summary_model, summary_vis_processors

def load_model(self, model_type):
def load_model(self, model_type: str):
"""
Load blip_caption model and preprocessors for visual inputs from lavis.models.

Args:
model_type (str): type of the model.

Returns:
model (torch.nn.Module): model.
vis_processors (dict): preprocessors for visual inputs.
"""
select_model = {
"base": SummaryDetector.load_model_base,
"large": SummaryDetector.load_model_large,
@@ -36,6 +64,16 @@ def load_model(self, model_type):
return summary_model, summary_vis_processors

def analyse_image(self, summary_model=None, summary_vis_processors=None):
"""
Create 1 constant and 3 non deterministic captions for image.

Args:
summary_model (str): model.
summary_vis_processors (str): preprocessors for visual inputs.

Returns:
self.subdict (dict): dictionary with constant image summary and 3 non deterministic summary.
"""
if summary_model is None and summary_vis_processors is None:
summary_model, summary_vis_processors = self.load_model_base()

@@ -55,7 +93,16 @@ def analyse_image(self, summary_model=None, summary_vis_processors=None):
)
return self.subdict

def analyse_questions(self, list_of_questions):
def analyse_questions(self, list_of_questions: list[str]) -> dict:
"""
Generate answers to free-form, natural-language questions about the image.

Args:
list_of_questions (list[str]): list of questions.

Returns:
self.subdict (dict): dictionary with answers to questions.
"""
(
summary_vqa_model,
summary_vqa_vis_processors,
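For orientation, a minimal usage sketch of the SummaryDetector methods documented above; the subdict layout, image path, and example question are assumptions for illustration, not part of this PR:

```python
# Sketch only: the subdict layout and file path are hypothetical.
from ammico.summary import SummaryDetector

mydict = {"filename": "data/example.png"}  # assumed per-image entry
detector = SummaryDetector(mydict)

# Load the captioning model once ("base" or "large") and reuse it across images.
summary_model, summary_vis_processors = detector.load_model("base")
result = detector.analyse_image(
    summary_model=summary_model, summary_vis_processors=summary_vis_processors
)

# Free-form visual question answering on the same image.
result = detector.analyse_questions(["How many people are shown?"])
```
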
25 changes: 12 additions & 13 deletions ammico/test/test_multimodal_search.py
@@ -354,24 +354,23 @@ def test_parsing_images(
tmp_path,
):
ms.MultimodalSearch.multimodal_device = pre_multimodal_device
my_obj = ms.MultimodalSearch(get_testdict)
(
model,
vis_processor,
txt_processor,
image_keys,
_,
features_image_stacked,
) = ms.MultimodalSearch.parsing_images(
get_testdict, pre_model, path_to_saved_tensors=tmp_path
)
) = my_obj.parsing_images(pre_model, path_to_save_tensors=tmp_path)

for i, num in zip(range(10), features_image_stacked[0, 10:12].tolist()):
assert (
math.isclose(num, pre_extracted_feature_img[i], rel_tol=related_error)
is True
)

test_pic = Image.open(get_testdict["IMG_2746"]["filename"]).convert("RGB")
test_pic = Image.open(my_obj.subdict["IMG_2746"]["filename"]).convert("RGB")
test_querry = (
"The bird sat on a tree located at the intersection of 23rd and 43rd streets."
)
@@ -387,10 +386,10 @@

search_query = [
{"text_input": test_querry},
{"image": get_testdict["IMG_2746"]["filename"]},
{"image": my_obj.subdict["IMG_2746"]["filename"]},
]
multi_features_stacked = ms.MultimodalSearch.querys_processing(
get_testdict, search_query, model, txt_processor, vis_processor, pre_model
multi_features_stacked = my_obj.querys_processing(
search_query, model, txt_processor, vis_processor, pre_model
)

for i, num in zip(range(10), multi_features_stacked[0, 10:12].tolist()):
@@ -410,8 +409,7 @@
{"image": get_path + "IMG_3758.png"},
]

similarity, sorted_list = ms.MultimodalSearch.multimodal_search(
get_testdict,
similarity, sorted_list = my_obj.multimodal_search(
model,
vis_processor,
txt_processor,
@@ -440,6 +438,7 @@
features_image_stacked,
processed_pic,
multi_features_stacked,
my_obj,
)
cuda.empty_cache()

@@ -452,12 +451,12 @@ def test_itm(get_test_my_dict, get_path):
]
image_keys = ["IMG_2746", "IMG_2809"]
sorted_list = [[1, 0], [1, 0]]
my_obj = ms.MultimodalSearch(get_test_my_dict)
for itm_model in ["blip_base", "blip_large"]:
(
itm_scores,
image_gradcam_with_itm,
) = ms.MultimodalSearch.image_text_match_reordering(
get_test_my_dict,
) = my_obj.image_text_match_reordering(
search_query3,
itm_model,
image_keys,
@@ -497,12 +496,12 @@ def test_itm_blip2_coco(get_test_my_dict, get_path):
]
image_keys = ["IMG_2746", "IMG_2809"]
sorted_list = [[1, 0], [1, 0]]
my_obj = ms.MultimodalSearch(get_test_my_dict)

(
itm_scores,
image_gradcam_with_itm,
) = ms.MultimodalSearch.image_text_match_reordering(
get_test_my_dict,
) = my_obj.image_text_match_reordering(
search_query3,
"blip2_coco",
image_keys,
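The test updates above reflect the refactor in ammico/multimodal_search.py from class-level calls (passing the test dict as the first argument) to methods on a MultimodalSearch instance, including the renamed keyword path_to_saved_tensors -> path_to_save_tensors. A rough sketch of the new calling pattern, limited to the signatures visible in the tests; the dict contents, model name, and query text are placeholders:

```python
# Sketch of the instance-based API; dict contents, model name, and query are placeholders.
from ammico import multimodal_search as ms

mydict = {"IMG_2746": {"filename": "data/IMG_2746.png"}}
my_obj = ms.MultimodalSearch(mydict)

# Extract image features and save them to disk with the renamed keyword.
(
    model,
    vis_processor,
    txt_processor,
    image_keys,
    _,
    features_image_stacked,
) = my_obj.parsing_images("blip", path_to_save_tensors=".")

# Process text/image queries on the same instance.
search_query = [{"text_input": "a bird sitting on a tree"}]
multi_features_stacked = my_obj.querys_processing(
    search_query, model, txt_processor, vis_processor, "blip"
)
```
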
8 changes: 7 additions & 1 deletion ammico/text.py
@@ -1,4 +1,5 @@
from google.cloud import vision
from google.auth.exceptions import DefaultCredentialsError
from googletrans import Translator
import spacy
from spacytextblob.spacytextblob import SpacyTextBlob
@@ -60,7 +61,12 @@ def analyse_image(self):
def get_text_from_image(self):
"""Detects text on the image."""
path = self.subdict["filename"]
client = vision.ImageAnnotatorClient()
try:
client = vision.ImageAnnotatorClient()
except DefaultCredentialsError:
raise DefaultCredentialsError(
"Please provide credentials for google cloud vision API, see https://cloud.google.com/docs/authentication/application-default-credentials."
)
with io.open(path, "rb") as image_file:
content = image_file.read()
image = vision.Image(content=content)
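The DefaultCredentialsError message added above refers to Google's application-default credentials. One common way to satisfy them before calling get_text_from_image() is to point the standard environment variable at a service-account key file; the path below is a placeholder:

```python
import os

# Application-default credentials for the Cloud Vision client;
# the key file path is a placeholder, not part of this PR.
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "/path/to/service-account-key.json"
```
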
2 changes: 1 addition & 1 deletion ammico/utils.py
@@ -45,7 +45,7 @@ def analyse_image(self):


def find_files(path=None, pattern="*.png", recursive=True, limit=20):
"""Find image files on the file system
"""Find image files on the file system.

:param path:
The base directory where we are looking for the images. Defaults
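A one-line usage sketch of find_files with the defaults shown in its signature; the data directory is a placeholder:

```python
from ammico.utils import find_files

# Collect up to 20 PNG files under a (placeholder) data directory.
images = find_files(path="data", pattern="*.png", recursive=True, limit=20)
```
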
16 changes: 16 additions & 0 deletions docs/source/misinformation.rst → docs/source/ammico.rst
@@ -50,6 +50,22 @@ cropposts module
----------------

.. automodule:: cropposts
:members:
:undoc-members:
:show-inheritance:

multimodal search module
------------------------

.. automodule:: multimodal_search
:members:
:undoc-members:
:show-inheritance:

summary module
--------------

.. automodule:: summary
:members:
:undoc-members:
:show-inheritance:
2 changes: 1 addition & 1 deletion docs/source/modules.rst
@@ -4,4 +4,4 @@ AMMICO package modules
.. toctree::
:maxdepth: 4

misinformation
ammico