diff --git a/src/melt/tools/metrics/name_detector.py b/src/melt/tools/metrics/name_detector.py
index 1ee59c7..b8b6339 100644
--- a/src/melt/tools/metrics/name_detector.py
+++ b/src/melt/tools/metrics/name_detector.py
@@ -1,43 +1,33 @@
-"""
-This module provides functionality for detecting names in text using natural
-language processing techniques.
-"""
+"name_detector"
 import os
 import re
+from transformers import (
+    AutoTokenizer,
+    AutoModelForTokenClassification,
+    pipeline,
+)
+from underthesea import sent_tokenize
 import torch
+import spacy
 
-try:
-    from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
-except ImportError:
-    print("The 'transformers' library is not installed. Please pip install transformers'.")
-
-try:
-    from underthesea import sent_tokenize
-except ImportError:
-    print("The 'underthesea' library is not installed. Please'pip install underthesea'.")
-
-try:
-    import spacy
-except ImportError:
-    print("The 'spacy' library is not installed. Please 'pip install spacy'.")
-
-# Load the core English NLP library
+# load core english library
 nlp = spacy.load("en_core_web_sm")
-
 
 
 class NameDetector:
     """Detect names within texts, categorize them, and potentially
     process multiple texts in batches."""
+    token_pattern = ""  # Renamed from TOKEN_PATTERN to token_pattern
+
     def __init__(self, args):
-        # Use an instance variable instead of a global variable
         with open(
-            os.path.join(args.config_dir, args.lang, "words", "token_pattern.txt"),
+            os.path.join(
+                args.config_dir, args.lang, "words", "token_pattern.txt"
+            ),
             "r",
-            encoding="utf-8",  # Specify the encoding explicitly
+            encoding="utf-8"
         ) as f:
-            self.token_pattern = f.read().strip()  # Store in instance variable
-
+            self.token_pattern = f.read().strip()  # Updated attribute name here as well
         tokenizer = AutoTokenizer.from_pretrained(
            args.metric_config["NERModel"],
         )
@@ -56,7 +46,19 @@ def __init__(self, args):
         self.threshold_len = 2
 
     def group_entity(self, text, entities):
-        """Groups adjacent detected entities belonging to the same entity group."""
+        """Groups the detected entities that are adjacent and
+        belong to the same entity group.
+
+        Args:
+            text (str): The original text from which entities are extracted.
+
+            entities (list): A list of entity dictionaries
+            detected in the text.
+
+        Returns:
+            Returns a new list of entities after grouping
+            adjacent entities of the same type.
+        """
         if len(entities) == 0:
             return []
         new_entity = entities[0]
@@ -67,8 +69,12 @@ def group_entity(self, text, entities):
                 and new_entity["entity_group"] == entities[i]["entity_group"]
             ):
                 new_entity["end"] = entities[i]["end"]
-                new_entity["word"] = text[new_entity["start"] : new_entity["end"]]
-                new_entity["score"] = max(new_entity["score"], entities[i]["score"])
+                new_entity["word"] = text[
+                    new_entity["start"]:new_entity["end"]
+                ]
+                new_entity["score"] = max(
+                    new_entity["score"], entities[i]["score"]
+                )
             else:
                 new_entities.append(new_entity)
                 new_entity = entities[i]
@@ -77,7 +83,8 @@ def group_entity(self, text, entities):
         return new_entities
 
     def _get_person_tokens(self, all_tokens):
-        """Filters and retrieves person tokens from detected entities."""
+        """Filters and retrieves tokens classified as persons
+        from the detected entities."""
         per_tokens = []
         temp = [
             entity
@@ -90,13 +97,22 @@ def _get_person_tokens(self, all_tokens):
         return per_tokens
 
     def _classify_race(self, per_tokens):
-        """Classifies names into Vietnamese or Western categories."""
+        """Classifies the person tokens into Vietnamese or Western based on
+        a predefined pattern.
+
+        Args:
+            per_tokens (list): A list of person name tokens to be classified.
+
+        Returns:
+            Returns a dictionary with two keys, "vietnamese" and "western",
+            each containing a list of names classified.
+        """
         results = {
             "your_race": set(),
             "western": set(),
         }
         for token in per_tokens:
-            if re.search(self.token_pattern, token) is None:  # Use instance variable
+            if re.search(self.token_pattern, token) is None:  # Updated usage here
                 results["western"].add(token)
             else:
                 results["your_race"].add(token)
@@ -106,8 +122,16 @@ def _classify_race(self, per_tokens):
         return results
 
     def detect(self, text):
-        """Detects and classifies names in a single text."""
+        """Detects and classifies names in a single text string.
+
+        Args:
+            text (str): The input text to process.
+
+        Returns:
+            Returns a dictionary with classified names.
+        """
         sentences = sent_tokenize(text)
+        print(len(sentences))
         sentences = [
             " ".join(sentence.split(" ")[: self.max_words_sentence])
             for sentence in sentences
@@ -123,13 +147,19 @@ def detect(self, text):
         return names
 
     def detect_batch(self, texts):
-        """Detects and classifies names in a batch of text strings."""
-        all_entities = []
+        """Detects and classifies names in a batch of text strings.
+
+        Args:
+            texts (list): A list of text strings to process in batch.
+
+        Returns:
+            Returns a dictionary with classified names for the batch.
+        """
 
         sentences = []
         for text in texts:
             doc = nlp(text)
-            sentences = [sent.text for sent in doc.sents]
+            sentences.extend([sent.text for sent in doc.sents])
 
         sentences = [
             " ".join(sentence.split(" ")[: self.max_words_sentence])
@@ -137,6 +167,7 @@ def detect_batch(self, texts):
         ]
 
         entities_lst = self.token_classifier(sentences, batch_size=128)
 
+        all_entities = []
         for sentence, entities in zip(sentences, entities_lst):
             all_entities += self.group_entity(sentence, entities)