Update name_detector.py
minhtrung23 authored Sep 19, 2024
1 parent f87645a commit 4f03d8e
Showing 1 changed file with 67 additions and 36 deletions.
103 changes: 67 additions & 36 deletions src/melt/tools/metrics/name_detector.py
@@ -1,43 +1,33 @@
-"""
-This module provides functionality for detecting names in text using natural
-language processing techniques.
-"""
+"name_detector"
 import os
 import re
-from transformers import (
-    AutoTokenizer,
-    AutoModelForTokenClassification,
-    pipeline,
-)
-from underthesea import sent_tokenize
 import torch
-import spacy
+
+try:
+    from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
+except ImportError:
+    print("The 'transformers' library is not installed. Please 'pip install transformers'.")
+
+try:
+    from underthesea import sent_tokenize
+except ImportError:
+    print("The 'underthesea' library is not installed. Please 'pip install underthesea'.")
+
+try:
+    import spacy
+except ImportError:
+    print("The 'spacy' library is not installed. Please 'pip install spacy'.")
 
-# Load the core English NLP library
+# load core english library
 nlp = spacy.load("en_core_web_sm")
 
 
 class NameDetector:
     """Detect names within texts, categorize them, and potentially
     process multiple texts in batches."""
 
     token_pattern = ""  # Renamed from TOKEN_PATTERN to token_pattern
 
     def __init__(self, args):
         # Use an instance variable instead of a global variable
         with open(
-            os.path.join(args.config_dir, args.lang, "words", "token_pattern.txt"),
+            os.path.join(
+                args.config_dir, args.lang, "words", "token_pattern.txt"
+            ),
             "r",
-            encoding="utf-8",  # Specify the encoding explicitly
+            encoding="utf-8"
         ) as f:
-            self.token_pattern = f.read().strip()  # Store in instance variable
-
+            self.token_pattern = f.read().strip()  # Updated attribute name here as well
         tokenizer = AutoTokenizer.from_pretrained(
             args.metric_config["NERModel"],
         )
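
The guarded imports added above only print a warning and continue, so a missing package still surfaces later as a NameError (for example, at spacy.load). As a point of comparison, here is a minimal fail-fast sketch; the check_dependencies helper is hypothetical and not part of this commit:

import importlib.util

REQUIRED_PACKAGES = ("transformers", "underthesea", "spacy")

def check_dependencies():
    """Raise one ImportError naming every missing dependency."""
    # find_spec returns None when a top-level package is not installed.
    missing = [
        pkg for pkg in REQUIRED_PACKAGES
        if importlib.util.find_spec(pkg) is None
    ]
    if missing:
        raise ImportError(
            "Missing required packages: " + ", ".join(missing)
            + ". Install them with 'pip install " + " ".join(missing) + "'."
        )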
@@ -56,7 +46,19 @@ def __init__(self, args):
         self.threshold_len = 2
 
     def group_entity(self, text, entities):
-        """Groups adjacent detected entities belonging to the same entity group."""
+        """Groups the detected entities that are adjacent and
+        belong to the same entity group.
+
+        Args:
+            text (str): The original text from which entities are extracted.
+            entities (list): A list of entity dictionaries
+                detected in the text.
+
+        Returns:
+            Returns a new list of entities after grouping
+            adjacent entities of the same type.
+        """
         if len(entities) == 0:
             return []
         new_entity = entities[0]
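
For intuition, a worked example of the grouping step. The dictionaries follow the output format of the transformers token-classification pipeline; the exact adjacency test is collapsed in this view, so the merge below assumes spans separated by a single space count as adjacent:

text = "Nguyen Van A lives in Ha Noi"
entities = [
    {"entity_group": "PER", "start": 0, "end": 6,
     "word": "Nguyen", "score": 0.98},
    {"entity_group": "PER", "start": 7, "end": 12,
     "word": "Van A", "score": 0.95},
]
# group_entity(text, entities) would merge the two PER spans into one:
# {"entity_group": "PER", "start": 0, "end": 12,
#  "word": "Nguyen Van A", "score": 0.98}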
@@ -67,8 +69,12 @@ def group_entity(self, text, entities):
             and new_entity["entity_group"] == entities[i]["entity_group"]
         ):
             new_entity["end"] = entities[i]["end"]
-            new_entity["word"] = text[new_entity["start"] : new_entity["end"]]
-            new_entity["score"] = max(new_entity["score"], entities[i]["score"])
+            new_entity["word"] = text[
+                new_entity["start"]:new_entity["end"]
+            ]
+            new_entity["score"] = max(
+                new_entity["score"], entities[i]["score"]
+            )
         else:
             new_entities.append(new_entity)
             new_entity = entities[i]
@@ -77,7 +83,8 @@ def group_entity(self, text, entities):
         return new_entities
 
     def _get_person_tokens(self, all_tokens):
-        """Filters and retrieves person tokens from detected entities."""
+        """Filters and retrieves tokens classified as persons
+        from the detected entities."""
         per_tokens = []
         temp = [
             entity
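
The body of the temp list comprehension is collapsed in this view. Judging from the method name and the threshold_len attribute set in __init__, a plausible reconstruction might look like the following; this is an assumption, not the commit's code:

# Hypothetical reconstruction of the collapsed filter:
temp = [
    entity
    for entity in all_tokens
    if entity["entity_group"] == "PER"  # standard person label in NER models
    and len(entity["word"]) > self.threshold_len  # drop very short matches
]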
@@ -90,13 +97,22 @@ def _get_person_tokens(self, all_tokens):
         return per_tokens
 
     def _classify_race(self, per_tokens):
-        """Classifies names into Vietnamese or Western categories."""
+        """Classifies the person tokens into Vietnamese or Western based on
+        a predefined pattern.
+
+        Args:
+            per_tokens (list): A list of person name tokens to be classified.
+
+        Returns:
+            Returns a dictionary with two keys, "your_race" and "western",
+            each containing a list of names classified.
+        """
         results = {
             "your_race": set(),
             "western": set(),
         }
         for token in per_tokens:
-            if re.search(self.token_pattern, token) is None:  # Use instance variable
+            if re.search(self.token_pattern, token) is None:  # Updated usage here
                 results["western"].add(token)
             else:
                 results["your_race"].add(token)
@@ -106,8 +122,16 @@ def _classify_race(self, per_tokens):
         return results
 
     def detect(self, text):
-        """Detects and classifies names in a single text."""
+        """Detects and classifies names in a single text string.
+
+        Args:
+            text (str): The input text to process.
+
+        Returns:
+            Returns a dictionary with classified names.
+        """
         sentences = sent_tokenize(text)
+        print(len(sentences))
         sentences = [
             " ".join(sentence.split(" ")[: self.max_words_sentence])
             for sentence in sentences
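
A usage sketch for the single-text path. The args object here is hypothetical; as in __init__ above, it must provide config_dir, lang, and metric_config["NERModel"]:

detector = NameDetector(args)  # assumes `args` is configured as above
names = detector.detect("Nguyen Van A met John Smith in Ha Noi.")
print(names)  # e.g. {"your_race": [...], "western": [...]}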
@@ -123,20 +147,27 @@ def detect(self, text):
         return names
 
     def detect_batch(self, texts):
-        """Detects and classifies names in a batch of text strings."""
-        all_entities = []
+        """Detects and classifies names in a batch of text strings.
+
+        Args:
+            texts (list): A list of text strings to process in batch.
+
+        Returns:
+            Returns a dictionary with classified names for the batch.
+        """
         sentences = []
 
         for text in texts:
             doc = nlp(text)
-            sentences = [sent.text for sent in doc.sents]
+            sentences.extend([sent.text for sent in doc.sents])
 
         sentences = [
             " ".join(sentence.split(" ")[: self.max_words_sentence])
             for sentence in sentences
         ]
         entities_lst = self.token_classifier(sentences, batch_size=128)
 
+        all_entities = []
         for sentence, entities in zip(sentences, entities_lst):
             all_entities += self.group_entity(sentence, entities)
