Update text_classification.py
minhtrung23 authored Sep 19, 2024
1 parent 64375a8 commit 6e154c1
Showing 1 changed file with 26 additions and 59 deletions.
85 changes: 26 additions & 59 deletions src/melt/tools/metrics/text_classification.py
@@ -1,90 +1,57 @@
"""Module for evaluating text classification models."""

from typing import Dict, Tuple
"test_classification"
from typing import Dict
import numpy as np
import evaluate
from sklearn.metrics import (
f1_score as f1_score_sklearn,
accuracy_score,
roc_auc_score,
)
from .utils import normalize_text
from .post_process import softmax_options_prob
from .base import BaseMetric

from melt.tools.metrics.utils import normalize_text
from melt.tools.metrics.post_process import softmax_options_prob
from melt.tools.metrics.base import BaseMetric
class TextClassificationMetric(BaseMetric):
"""Evaluate text classification models."""

def __init__(self, data, args):
super().__init__(data, args)
# Ensure 'evaluate' is correctly installed and used, or remove if not needed
self.roc_auc_score = None # Remove if not used
self.data =data

def evaluate(self, data: Dict, args) -> Tuple[Dict, Dict]:
self.roc_auc_score = evaluate.load("roc_auc", "multiclass")
def evaluate(self, data: Dict, args) -> tuple[Dict, Dict]:
"""Evaluates the classification performance
given the predictions, references, and additional arguments.
Args:
data (Dict): A dictionary expected to contain keys
like predictions, references, and option_probs.
args: Additional arguments including class_names.
Returns:
Tuple[Dict, Dict]: The original data dictionary and
Returns a tuple containing the original data dictionary and
the result dictionary with all the computed metrics.
"""
result = {}
raw_predictions = data["predictions"]
args.class_names = [normalize_text(str(name)) for name in args.class_names]
predictions = [
str(self._get_answer(raw_prediction, args))
for raw_prediction in raw_predictions
]
references = self._normalize_references(data["references"], args)

predictions = [str(self._get_answer(raw_prediction, args))
for raw_prediction in data["predictions"]]
references = self._process_references(data["references"], predictions)
result["accuracy"] = accuracy_score(references, predictions)
result["f1_score"] = f1_score_sklearn(
references, predictions, average="macro"
)

sum_option_probs = [
[np.array(x).sum() for x in probs]
for probs in data["option_probs"]
]

result["f1_score"] = f1_score_sklearn(references, predictions, average="macro")
sum_option_probs = [[np.array(x).sum() for x in option_prob]
for option_prob in data["option_probs"]]
probs = softmax_options_prob(sum_option_probs)
if len(args.class_names) == 2:
probs = probs[:, 1].reshape(-1, 1)
labels = np.array([
args.class_names.index(ref) for ref in references
])

labels = np.array([args.class_names.index(ref) for ref in references])
try:
result["roc_auc"] = roc_auc_score(
labels, probs, multi_class="ovr", average="macro"
)
except (ValueError, TypeError, IndexError) as e:
print(f"Error calculating ROC AUC: {e}")
result["roc_auc"] = roc_auc_score(labels, probs, multi_class="ovr", average="macro")
except ValueError as e:
print(f"ROC AUC calculation failed: {e}")
result["roc_auc"] = None

return data, result
def reset_data(self, new_data):
"""Resets the data with new data."""
self.data = new_data
def _normalize_references(self, references, args):
"""Helper function to normalize references."""

normalized_references = []
for reference in references:
def _process_references(self, references, predictions):
processed_references = []
for reference, prediction in zip(references, predictions):
if isinstance(reference, list):
reference = [normalize_text(str(ref)) for ref in reference]
first_ref = str(normalize_text(reference[0]))
answer = self._get_answer(reference, args)
if answer in reference:
normalized_references.append(first_ref)
else:
normalized_references.append(str(reference[0]))
processed_references.append(str(normalize_text(prediction)
if prediction in reference else reference[0]))
else:
normalized_references.append(normalize_text(str(reference)))
return list(normalized_references)
processed_references.append(normalize_text(str(reference)))
return processed_references
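The most substantive behavioural change is in the new _process_references helper: for list-valued references it scores against the model's prediction whenever that prediction appears among the acceptable (normalized) references, and otherwise falls back to the first reference. A minimal, self-contained sketch of that rule, using a lowercase/strip stand-in for normalize_text (an assumption, not the real helper from melt.tools.metrics.utils):

def normalize(text):
    # Stand-in for normalize_text; the real implementation may differ.
    return str(text).lower().strip()

def process_reference(reference, prediction):
    # Mirrors the per-item logic of the new _process_references.
    if isinstance(reference, list):
        acceptable = [normalize(ref) for ref in reference]
        # Treat the prediction as the reference when it is an acceptable answer,
        # otherwise fall back to the first gold reference.
        return normalize(prediction) if normalize(prediction) in acceptable else acceptable[0]
    return normalize(reference)

print(process_reference(["Yes", "Y"], "yes"))  # -> "yes" (prediction accepted, scored correct)
print(process_reference(["Yes", "Y"], "no"))   # -> "yes" (falls back to the first reference)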

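To sanity-check the metric computation that evaluate() now performs without installing the melt package, here is a self-contained sketch using only numpy and scikit-learn. The option_probs layout and the plain softmax standing in for softmax_options_prob are assumptions about shapes, not the package's documented API:

import numpy as np
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score

class_names = ["negative", "positive"]              # args.class_names, already normalized
predictions = ["positive", "negative", "positive"]  # model answers after _get_answer
references = ["positive", "negative", "negative"]   # gold labels after _process_references
# One inner list of raw scores per class option, per example (assumed layout).
option_probs = [
    [[0.1, 0.2], [0.9, 0.8]],
    [[0.7], [0.3]],
    [[0.2, 0.1], [0.6, 0.5]],
]

result = {}
result["accuracy"] = accuracy_score(references, predictions)
result["f1_score"] = f1_score(references, predictions, average="macro")

# Mirrors sum_option_probs + softmax_options_prob: sum each option's scores,
# then softmax across options to get per-class probabilities.
sums = np.array([[np.sum(x) for x in probs] for probs in option_probs])
probs = np.exp(sums) / np.exp(sums).sum(axis=1, keepdims=True)

labels = np.array([class_names.index(ref) for ref in references])
if len(class_names) == 2:
    probs = probs[:, 1]  # binary case: scores for class index 1 (the commit keeps a column via reshape(-1, 1))
try:
    result["roc_auc"] = roc_auc_score(labels, probs, multi_class="ovr", average="macro")
except ValueError as err:
    print(f"ROC AUC calculation failed: {err}")
    result["roc_auc"] = None

print(result)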