This is the code for the PRICAI 2024 paper "Enhanced Cognitive Distortions Detection and Classification Through Data Augmentation Techniques".
By Mohamad Rasmy, Caroline Sabty, Nourhan Sakr, and Alia El Bolock.
We present an enhanced approach for detecting and classifying cognitive distortions, building upon previous work by Elsharawi et al. (2024) that identified a CNN model using BERT embeddings as the most effective method. Our study explores additional embeddings from RoBERTa and GPT-2, implements fine-tuning of these models, and applies various data augmentation techniques to enhance the training dataset. The data augmentation techniques used in this study are:
- Easy Data Augmentation (EDA): This technique includes Synonym Replacement (SR), where random words in a sentence are replaced with their synonyms, and Random Insertion (RI), where a synonym of a random word is inserted at a random position in the sentence.
- Word Embedding Substitution (WE_Sub): This technique replaces words in a sentence with other words that have similar embeddings, providing greater flexibility and variation compared to synonym-based replacement.
- Back-Translation (BT): This technique involves translating a sentence from the source language to one or more intermediate languages and then back to the original language, introducing lexical and syntactic variations while preserving the overall semantic meaning.
We design our classification approach to address three distinct tasks:
- Binary Classification (Distortion Detection): Distinguishing between distorted and neutral data.
- Multi-Class Classification (Distortion Classification): Classifying distorted samples into fourteen cognitive distortion types.
- Semantic Grouping Classification: Grouping the fourteen distortion types into five semantic categories to address class imbalance and simplify the classification task.
Each script can be run independently with specific command-line arguments.
The back-translation.py script performs back-translation using Opus MT models.
python back-translation.py --data_path path/to/your/input.csv
Arguments:
- --data_path: Path to the input CSV file containing the data to be augmented.
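The round trip behind back-translation can be sketched as follows. This is a minimal illustration and not the code in back-translation.py: the helper names (`make_marian_translator`, `back_translate`) are hypothetical, and the English-German pair in the usage note is an assumption, since the intermediate languages are not stated here. The Opus MT checkpoints are loaded through the `MarianMTModel` and `MarianTokenizer` classes from Hugging Face `transformers`.

```python
from typing import Callable, List, Tuple

# A translator maps a batch of sentences to a batch of translations.
Translator = Callable[[List[str]], List[str]]


def make_marian_translator(model_name: str) -> Translator:
    """Wrap an Opus MT (MarianMT) checkpoint as a batch translation function."""
    # Imported lazily so the rest of the sketch runs without transformers installed.
    from transformers import MarianMTModel, MarianTokenizer

    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)

    def translate(texts: List[str]) -> List[str]:
        batch = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
        generated = model.generate(**batch)
        return tokenizer.batch_decode(generated, skip_special_tokens=True)

    return translate


def back_translate(
    sentences: List[str], to_pivot: Translator, to_source: Translator
) -> List[Tuple[str, str]]:
    """Translate to a pivot language and back; return (original, paraphrase) pairs."""
    round_trip = to_source(to_pivot(sentences))
    # Assumption: a round trip that reproduces the input adds no variation, so drop it.
    return [
        (original, paraphrase)
        for original, paraphrase in zip(sentences, round_trip)
        if paraphrase.strip().lower() != original.strip().lower()
    ]
```

With the models downloaded, a call could look like `back_translate(sentences, make_marian_translator("Helsinki-NLP/opus-mt-en-de"), make_marian_translator("Helsinki-NLP/opus-mt-de-en"))`. Passing the translators as arguments keeps the round-trip logic independent of any particular language pair.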
The sr-ri-eda.py script implements two Easy Data Augmentation (EDA) techniques: Synonym Replacement (SR) and Random Insertion (RI).
python sr-ri-eda.py --alpha 0.4 --total_augments_per_sentence 4 --data_path path/to/your/input.csv
Arguments:
- --alpha: Percentage of words to be changed in a sentence (float between 0 and 1).
- --total_augments_per_sentence: Total number of augmentations to generate per sentence.
- --data_path: Path to the input CSV file containing the data to be augmented.
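To make the roles of alpha and total_augments_per_sentence concrete, here is a minimal sketch of SR and RI. It is not the code in sr-ri-eda.py. Several things are assumptions: the tiny `TOY_SYNONYMS` thesaurus is hypothetical (a real run would draw on a lexical resource such as WordNet), the number of edits is taken as `max(1, int(alpha * len(words)))` following the usual EDA convention, and the augmentations simply alternate between SR and RI.

```python
import random
from typing import Dict, List

# Hypothetical toy thesaurus used only for illustration.
TOY_SYNONYMS: Dict[str, List[str]] = {
    "always": ["constantly", "forever"],
    "fail": ["flop", "falter"],
    "terrible": ["awful", "dreadful"],
}


def num_changes(words: List[str], alpha: float) -> int:
    """Number of words to edit: a fraction alpha of the sentence, at least one."""
    return max(1, int(alpha * len(words)))


def synonym_replacement(words, alpha, synonyms, rng):
    """SR: replace up to num_changes random words that have a known synonym."""
    new_words = list(words)
    candidates = [i for i, w in enumerate(words) if w.lower() in synonyms]
    rng.shuffle(candidates)
    for i in candidates[: num_changes(words, alpha)]:
        new_words[i] = rng.choice(synonyms[words[i].lower()])
    return new_words


def random_insertion(words, alpha, synonyms, rng):
    """RI: insert synonyms of random words at random positions."""
    new_words = list(words)
    sources = [w for w in words if w.lower() in synonyms]
    if not sources:
        return new_words
    for _ in range(num_changes(words, alpha)):
        synonym = rng.choice(synonyms[rng.choice(sources).lower()])
        new_words.insert(rng.randint(0, len(new_words)), synonym)
    return new_words


def eda(sentence, alpha=0.4, total_augments_per_sentence=4,
        synonyms=TOY_SYNONYMS, seed=0):
    """Return total_augments_per_sentence variants, alternating SR and RI."""
    rng = random.Random(seed)
    words = sentence.split()
    operations = [synonym_replacement, random_insertion]
    return [
        " ".join(operations[k % 2](words, alpha, synonyms, rng))
        for k in range(total_augments_per_sentence)
    ]
```

For a six-word sentence and alpha = 0.4, each SR variant changes two words and each RI variant adds two words.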
The we-substitution-bert.py script performs word substitution using BERT contextual embeddings.
python we-substitution-bert.py --alpha 0.3 --data_path path/to/your/input.csv
Arguments:
- --alpha: Percentage of words to be augmented in each sentence (float between 0 and 1).
- --data_path: Path to the input CSV file containing the data to be augmented.
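One common way to obtain context-aware substitutes from BERT is to mask a word and let the model predict a replacement. The sketch below takes that route as an assumption; we-substitution-bert.py may select substitutes differently. The names `make_bert_proposer` and `substitute_words` are hypothetical, and the number of substituted words is assumed to be `max(1, int(alpha * len(words)))`.

```python
import random
from typing import Callable, List

# A proposer receives the tokenized sentence and an index, and returns a substitute word.
Proposer = Callable[[List[str], int], str]


def make_bert_proposer(model_name: str = "bert-base-uncased") -> Proposer:
    """Build a proposer backed by the transformers fill-mask pipeline."""
    # Imported lazily so substitute_words can be used without transformers installed.
    from transformers import pipeline

    fill_mask = pipeline("fill-mask", model=model_name)
    mask = fill_mask.tokenizer.mask_token

    def propose(words: List[str], index: int) -> str:
        masked = " ".join(words[:index] + [mask] + words[index + 1:])
        for candidate in fill_mask(masked):
            token = candidate["token_str"].strip()
            # Skip sub-word pieces, punctuation, and the original word itself.
            if token.isalpha() and token.lower() != words[index].lower():
                return token
        return words[index]  # no usable prediction: keep the original word

    return propose


def substitute_words(sentence: str, alpha: float, propose: Proposer, seed: int = 0) -> str:
    """Replace a fraction alpha of the words (at least one) using propose."""
    rng = random.Random(seed)
    words = sentence.split()
    if not words:
        return sentence
    n_changes = min(len(words), max(1, int(alpha * len(words))))
    for index in rng.sample(range(len(words)), n_changes):
        words[index] = propose(words, index)
    return " ".join(words)
```

A call such as `substitute_words("I feel like a total failure", 0.3, make_bert_proposer())` would then replace one of the six words with a model prediction.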
- The input CSV file should contain a column named 'sentence' with the text data to be augmented.
- Each script will output a new CSV file with the augmented data in the same directory as the script.
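The input and output convention above can be sketched with the standard library alone. This is an illustration under assumptions, not the scripts' own I/O code: `augment_csv` is a hypothetical helper, the output path is chosen by the caller, other columns (for example a label) are copied unchanged, and only the augmented rows are written, since it is not stated whether the scripts also keep the original rows.

```python
import csv
from typing import Callable, List


def augment_csv(data_path: str, output_path: str,
                augment: Callable[[str], List[str]]) -> int:
    """Read the 'sentence' column, write one output row per augmented sentence.

    Returns the number of rows written.
    """
    with open(data_path, newline="", encoding="utf-8") as src:
        reader = csv.DictReader(src)
        if reader.fieldnames is None or "sentence" not in reader.fieldnames:
            raise ValueError("input CSV must contain a 'sentence' column")
        fieldnames = list(reader.fieldnames)
        rows = list(reader)

    written = 0
    with open(output_path, "w", newline="", encoding="utf-8") as dst:
        writer = csv.DictWriter(dst, fieldnames=fieldnames)
        writer.writeheader()
        for row in rows:
            for new_sentence in augment(row["sentence"]):
                # Copy every other column so labels stay attached to the new text.
                writer.writerow({**row, "sentence": new_sentence})
                written += 1
    return written
```

Any function that maps one sentence to a list of variants can be passed as `augment`.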
If you use this in your paper, please cite us:
@InProceedings{10.1007/978-981-96-0116-5_11,
author="Rasmy, Mohamad
and Sabty, Caroline
and Sakr, Nourhan
and El Bolock, Alia",
title="Enhanced Cognitive Distortions Detection and Classification Through Data Augmentation Techniques",
booktitle="PRICAI 2024: Trends in Artificial Intelligence",
year="2025",
publisher="Springer Nature Singapore",
pages="134--145"
}
The experimental training code for cognitive distortion detection used in the paper can be found here. The following hyperparameters at the start of the file control the training setup:
- classification: Set to 2, 5, or 14 to select the classification strategy.
- val_pct: Fraction of the data held out for the validation and test sets combined.
- test_pct: Fraction of val_pct assigned to the test set. For example, if val_pct = 0.3 and test_pct = 2/3, then the test set is (2/3)*0.3 = 0.2 of the data and the validation set is the remaining 0.3 - 0.2 = 0.1.
- model_name: The transformer model name string (e.g., "bert-base-uncased", "roberta-base", "openai-community/gpt2").
- is_CNN: Boolean. If true, a CNN architecture is trained on embeddings from the transformer model_name; if false, the transformer model_name itself is fine-tuned.
- api_key: String set to your wandb API key.
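The arithmetic behind val_pct and test_pct can be checked with a short sketch. The function names are hypothetical and the shuffling scheme is an assumption; the training file may split the data differently (for example with stratification).

```python
import random
from typing import List, Tuple


def split_fractions(val_pct: float, test_pct: float) -> Tuple[float, float, float]:
    """Return (train, validation, test) fractions of the full dataset."""
    test_fraction = val_pct * test_pct       # share of val_pct given to the test set
    val_fraction = val_pct - test_fraction   # what remains for validation
    train_fraction = 1.0 - val_pct
    return train_fraction, val_fraction, test_fraction


def split_indices(n_samples: int, val_pct: float, test_pct: float,
                  seed: int = 0) -> Tuple[List[int], List[int], List[int]]:
    """Shuffle sample indices and cut them into train, validation, and test lists."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    n_holdout = round(n_samples * val_pct)
    n_test = round(n_holdout * test_pct)
    holdout, train = indices[:n_holdout], indices[n_holdout:]
    return train, holdout[n_test:], holdout[:n_test]
```

With val_pct = 0.3 and test_pct = 2/3, `split_fractions` gives roughly 0.7 / 0.1 / 0.2, and `split_indices(100, 0.3, 2/3)` yields 70 training, 10 validation, and 20 test samples.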