
Commit

Fix convention for src/melt/tools/metrics/name_detector.py
minhtrung23 committed Sep 8, 2024
1 parent 5e05d5a commit 4eb2711
Showing 1 changed file with 23 additions and 75 deletions.
98 changes: 23 additions & 75 deletions src/melt/tools/metrics/name_detector.py
@@ -1,32 +1,33 @@
-from transformers import (
-    AutoTokenizer,
-    AutoModelForTokenClassification,
-    pipeline,
-)
-from underthesea import sent_tokenize
-import torch
+"""
+This module provides functionality for detecting names in text using natural
+language processing techniques.
+"""
+
 import os
 import re
 
+from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
+from underthesea import sent_tokenize
+import torch
 import spacy
 
-# load core english library
+# Load the core English NLP library
 nlp = spacy.load("en_core_web_sm")
-token_pattern = ""
 
 
 class NameDetector:
     """Detect names within texts, categorize them, and potentially
     process multiple texts in batches."""
 
     def __init__(self, args):
-        global token_pattern
+        # Use an instance variable instead of a global variable
         with open(
-            os.path.join(
-                args.config_dir, args.lang, "words", "token_pattern.txt"
-            ),
+            os.path.join(args.config_dir, args.lang, "words", "token_pattern.txt"),
             "r",
+            encoding="utf-8",  # Specify the encoding explicitly
        ) as f:
-            token_pattern = f.read().strip()
+            self.token_pattern = f.read().strip()  # Store in instance variable
 
        tokenizer = AutoTokenizer.from_pretrained(
            args.metric_config["NERModel"],
        )
@@ -45,19 +46,7 @@ def __init__(self, args):
         self.threshold_len = 2
 
     def group_entity(self, text, entities):
-        """Groups the detected entities that are adjacent and
-        belong to the same entity group.
-
-        Args:
-            text (str): The original text from which entities are extracted.
-
-            entities (list): A list of entity dictionaries
-            detected in the text.
-
-        Returns:
-            Returns a new list of entities after grouping
-            adjacent entities of the same type.
-        """
+        """Groups adjacent detected entities belonging to the same entity group."""
         if len(entities) == 0:
             return []
         new_entity = entities[0]
@@ -68,12 +57,8 @@ def group_entity(self, text, entities):
                 and new_entity["entity_group"] == entities[i]["entity_group"]
             ):
                 new_entity["end"] = entities[i]["end"]
-                new_entity["word"] = text[
-                    new_entity["start"]:new_entity["end"]
-                ]
-                new_entity["score"] = max(
-                    new_entity["score"], entities[i]["score"]
-                )
+                new_entity["word"] = text[new_entity["start"] : new_entity["end"]]
+                new_entity["score"] = max(new_entity["score"], entities[i]["score"])
             else:
                 new_entities.append(new_entity)
                 new_entity = entities[i]
@@ -82,18 +67,7 @@ def group_entity(self, text, entities):
         return new_entities
 
     def _get_person_tokens(self, all_tokens):
-        """Filters and retrieves tokens classified as persons
-        from the detected entities
-        based on the threshold score and length.
-
-        Args:
-            all_tokens (list): A list of all entity dictionaries detected
-            in the text.
-
-        Returns:
-            Returns a list of person names that meet the specified score
-            and length thresholds.
-        """
+        """Filters and retrieves person tokens from detected entities."""
         per_tokens = []
         temp = [
             entity
@@ -102,27 +76,17 @@ def _get_person_tokens(self, all_tokens):
             and len(entity["word"]) > self.threshold_len
             and entity["score"] > self.threshold_score
         ]
-        # print(temp)
         per_tokens.extend([entity["word"] for entity in temp])
         return per_tokens
 
     def _classify_race(self, per_tokens):
-        """Classifies the person tokens into Vietnamese or Western based on
-        a predefined pattern.
-
-        Args:
-            per_tokens (list): A list of person name tokens to be classified.
-
-        Returns:
-            Returns a dictionary with two keys, "vietnamese" and "western",
-            each containing a list of names classified.
-        """
+        """Classifies names into Vietnamese or Western categories."""
         results = {
             "your_race": set(),
             "western": set(),
         }
         for token in per_tokens:
-            if re.search(token_pattern, token) is None:
+            if re.search(self.token_pattern, token) is None:  # Use instance variable
                 results["western"].add(token)
             else:
                 results["your_race"].add(token)
@@ -132,17 +96,8 @@ def _classify_race(self, per_tokens):
         return results
 
     def detect(self, text):
-        """Detects and classifies names in a single text string.
-
-        Args:
-            text (str): The input text to process.
-
-        Returns:
-            Returns a dictionary with classified names.
-        """
-        all_entities = []
+        """Detects and classifies names in a single text."""
         sentences = sent_tokenize(text)
-        print(len(sentences))
         sentences = [
             " ".join(sentence.split(" ")[: self.max_words_sentence])
             for sentence in sentences
@@ -158,14 +113,7 @@ def detect(self, text):
         return names
 
     def detect_batch(self, texts):
-        """Detects and classifies names in a batch of text strings.
-
-        Args:
-            texts (list): A list of text strings to process in batch.
-
-        Returns:
-            Returns a dictionary with classified names for the batch.
-        """
+        """Detects and classifies names in a batch of text strings."""
         all_entities = []
         sentences = []
 
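Note: for readers skimming the diff, below is a minimal standalone sketch of the group_entity merging behaviour on toy NER output. The adjacency test in the first condition and the final append after the loop sit outside the hunks shown above, so both are assumptions, as are the sample text and entity dicts.

def group_entity(text, entities):
    if len(entities) == 0:
        return []
    new_entity = entities[0]
    new_entities = []
    for i in range(1, len(entities)):
        if (
            new_entity["end"] >= entities[i]["start"] - 1  # assumed adjacency test
            and new_entity["entity_group"] == entities[i]["entity_group"]
        ):
            # Extend the running entity and re-slice its surface form from text.
            new_entity["end"] = entities[i]["end"]
            new_entity["word"] = text[new_entity["start"] : new_entity["end"]]
            new_entity["score"] = max(new_entity["score"], entities[i]["score"])
        else:
            new_entities.append(new_entity)
            new_entity = entities[i]
    new_entities.append(new_entity)  # assumed: flush the last running entity
    return new_entities

text = "Nguyen Van A met John Smith"
entities = [  # hypothetical pipeline output over `text`
    {"entity_group": "PER", "start": 0, "end": 6, "word": "Nguyen", "score": 0.98},
    {"entity_group": "PER", "start": 7, "end": 12, "word": "Van A", "score": 0.95},
    {"entity_group": "PER", "start": 17, "end": 27, "word": "John Smith", "score": 0.99},
]
print(group_entity(text, entities))
# -> "Nguyen Van A" merged into one PER span; "John Smith" stays separate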

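Note: _classify_race keys off a regex loaded from <config_dir>/<lang>/words/token_pattern.txt. Here is a self-contained sketch of that step with an invented stand-in pattern; the real file's contents are not part of this commit.

import re

# Invented stand-in for the pattern file's contents (illustration only).
token_pattern = r"Nguy|Trần|Lê|Phạm|Văn|Thị"

def classify_race(per_tokens, pattern):
    results = {"your_race": set(), "western": set()}
    for token in per_tokens:
        # Tokens that do not match the pattern are treated as Western names.
        if re.search(pattern, token) is None:
            results["western"].add(token)
        else:
            results["your_race"].add(token)
    # The real method may post-process these sets in lines not shown in the diff.
    return results

print(classify_race(["Nguyễn Văn A", "John Smith"], token_pattern))
# -> {'your_race': {'Nguyễn Văn A'}, 'western': {'John Smith'}}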

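Note: a hedged end-to-end usage sketch. The config values and NER checkpoint name below are placeholders; only the attribute names (args.config_dir, args.lang, args.metric_config["NERModel"]) and the detect entry point appear in the diff, and __init__ may read further attributes not shown here.

from types import SimpleNamespace

from melt.tools.metrics.name_detector import NameDetector

# All values below are placeholders; only the attribute names are taken
# from the diff itself.
args = SimpleNamespace(
    config_dir="config",
    lang="vi",
    metric_config={"NERModel": "your-org/your-ner-checkpoint"},
)

detector = NameDetector(args)
names = detector.detect("Nguyễn Văn A đã gặp John Smith tại Hà Nội.")
print(names)  # expected: a dict with "your_race" and "western" name groups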