From 97925db11c01ee6e4c8c0dc763566b0f47518255 Mon Sep 17 00:00:00 2001
From: Daniel Deutsch
Date: Thu, 28 Oct 2021 10:47:04 -0400
Subject: [PATCH] Adding Docker version of Prism

---
# Review notes. This area (after the three-dash line, before the diffstat) is
# commentary only; it is not part of the commit message or of the diff.
#
# What the patch does
# * gem_metrics/__init__.py: inside metric_list_to_metric_dict, adds the entry
#   "prism": "Prism" next to the existing "sari", "nubia" and "questeval"
#   entries.
# * gem_metrics/prism.py (new): class Prism(ReferencedMetric).
#   - __init__(device, language="en") builds a
#     repro.models.thompson2020.Prism with those two keyword arguments and
#     stores it on self.prism.
#   - compute(cache, predictions, references) pairs predictions.untokenized
#     with references.untokenized into {"candidate": ..., "references": ...}
#     dicts and passes the list to self.prism.predict_batch. The first
#     returned value is discarded; the second (`micro`) is zipped with
#     predictions.ids to build the returned {id: score_dict} mapping.
#   - When cache is not None, each score_dict is also stored under the key
#     (class name, predictions.filename, pred_id).
#   - Per the in-code example, each `micro` element looks like
#     {'prism': <float>}.
# * requirements-heavy.txt: adds the pin repro==0.1.2.
# * tests/test_prism.py (new): TestPrism mixes TestReferencedMetric with
#   unittest.TestCase; setUp builds Prism(device=-1) and sets four expected
#   result dicts keyed by "prism".
#
# NOTE(review)
# * `device` has no default value. The test passes device=-1, presumably
#   meaning CPU -- TODO confirm against the repro documentation. If any caller
#   instantiates metric classes without arguments, Prism() would raise
#   TypeError -- verify against the code that consumes the mapping in
#   metric_list_to_metric_dict.
# * Both loops in compute use zip(), which stops at the shorter input. If
#   predict_batch ever returned fewer items than predictions.ids, the missing
#   ids would be dropped from the result without any error.
# * gem_metrics/prism.py and requirements-heavy.txt are both left without a
#   trailing newline (see the "\ No newline at end of file" markers below).
# * The subject mentions Docker but nothing in the diff references it;
#   presumably the repro package runs the model in a container -- confirm.
# * The copy of this patch under review had its line breaks and leading
#   whitespace collapsed. The indentation below is reconstructed using
#   conventional 4-space Python indentation; check it against commit 97925db
#   before applying.
#
 gem_metrics/__init__.py |  1 +
 gem_metrics/prism.py    | 32 ++++++++++++++++++++++++++++++++
 requirements-heavy.txt  |  1 +
 tests/test_prism.py     | 17 +++++++++++++++++
 4 files changed, 51 insertions(+)
 create mode 100644 gem_metrics/prism.py
 create mode 100644 tests/test_prism.py

diff --git a/gem_metrics/__init__.py b/gem_metrics/__init__.py
index 4749553..769a352 100644
--- a/gem_metrics/__init__.py
+++ b/gem_metrics/__init__.py
@@ -50,6 +50,7 @@ def metric_list_to_metric_dict(metric_list: List[str]) -> Dict[str, List]:
         "sari": "SARI",
         "nubia": "NUBIA",
         "questeval": "QuestEval",
+        "prism": "Prism",
     }
 
     referenced_list, referenceless_list, sourced_and_referenced_list = [], [], []
diff --git a/gem_metrics/prism.py b/gem_metrics/prism.py
new file mode 100644
index 0000000..6f195d3
--- /dev/null
+++ b/gem_metrics/prism.py
@@ -0,0 +1,32 @@
+#!/usr/bin/env python3
+
+from repro.models.thompson2020 import Prism as _Prism
+
+from .metric import ReferencedMetric
+
+
+class Prism(ReferencedMetric):
+    def __init__(self, device: int, language: str = "en"):
+        self.prism = _Prism(device=device, language=language)
+
+    def compute(self, cache, predictions, references):
+        inputs = []
+        for pred, refs in zip(predictions.untokenized, references.untokenized):
+            inputs.append({
+                "candidate": pred,
+                "references": refs
+            })
+
+        # Example `micro` output for 3 inputs
+        # [{'prism': -1.1578280925750732}, {'prism': -1.3325390815734863}, {'prism': -2.730839729309082}]
+        _, micro = self.prism.predict_batch(inputs)
+
+        # Write to cache if not None and collect outputs
+        id_to_scores = {}
+        for pred_id, score_dict in zip(predictions.ids, micro):
+            id_to_scores[pred_id] = score_dict
+            if cache is not None:
+                cache_key = (self.__class__.__name__, predictions.filename, pred_id)
+                cache[cache_key] = score_dict
+
+        return id_to_scores
\ No newline at end of file
diff --git a/requirements-heavy.txt b/requirements-heavy.txt
index b2f9c3a..cf8d096 100644
--- a/requirements-heavy.txt
+++ b/requirements-heavy.txt
@@ -2,3 +2,4 @@ nubia-score
 bert_score
 git+https://github.com/google-research/bleurt.git
 questeval==0.2.4
+repro==0.1.2
\ No newline at end of file
diff --git a/tests/test_prism.py b/tests/test_prism.py
new file mode 100644
index 0000000..10fd53d
--- /dev/null
+++ b/tests/test_prism.py
@@ -0,0 +1,17 @@
+import unittest
+import gem_metrics.prism
+from tests.test_referenced import TestReferencedMetric
+
+
+class TestPrism(TestReferencedMetric, unittest.TestCase):
+    def setUp(self):
+        super().setUp()
+        self.metric = gem_metrics.prism.Prism(device=-1)
+        self.true_results_basic = {"prism": -1.7404}
+        self.true_results_identical_pred_ref = {"prism": -0.229}
+        self.true_results_mismatched_pred_ref = {"prism": -6.62654}
+        self.true_results_empty_pred = {"prism": -9.90867}
+
+
+if __name__ == "__main__":
+    unittest.main()