update preprocess
csong27 committed Jul 27, 2020
1 parent 81de7c8 commit e83cbc4
Showing 8 changed files with 266 additions and 15 deletions.
3 changes: 0 additions & 3 deletions .gitignore

This file was deleted.

6 changes: 5 additions & 1 deletion constant.py
@@ -28,12 +28,16 @@

# processed MIMIC-III data
PROCESSED_DIR = f'{ICD_DATA_DIR}/processed'
+if not os.path.exists(PROCESSED_DIR):
+    os.mkdir(PROCESSED_DIR)

# path for caching all processed training data
CACHE_PATH = f'{PROCESSED_DIR}/preloaded_train.npz'

# resources including vocabs, embeddings, ICD code domain knowledge, keywords, etc.
RESOURCES_DIR = f'{ICD_DATA_DIR}/resources'
-VOCAB_PATH = f'{RESOURCES_DIR}/vocab_to_ix.pkl'
+VOCAB_PATH = f'{RESOURCES_DIR}/vocab.txt'
+VOCAB_DICT_PATH = f'{RESOURCES_DIR}/vocab_to_ix.pkl'
EMBEDDING_PATH = f'{RESOURCES_DIR}/icd_word_emb.pkl'
KEYWORDS_PATH = f'{RESOURCES_DIR}/mimic3_note_keywords.pkl'
ICD_CODE_HIERARCHY_PATH = f'{RESOURCES_DIR}/icd_code_hierarchy.txt'
14 changes: 8 additions & 6 deletions loaders/data_loader.py
@@ -15,7 +15,6 @@

import copy
import os
-import pickle
from collections import defaultdict

import numpy as np
@@ -26,9 +25,13 @@
from utils.helper import log


-def load_vocab(vocab_path=VOCAB_PATH):
-    with open(vocab_path, 'rb') as f:
-        vocab_to_ix = pickle.load(f)
+def load_vocab():
+    vocab_to_ix = dict()
+    with open(VOCAB_PATH) as f:
+        i = 0
+        for line in f:
+            vocab_to_ix[line.strip()] = i
+            i += 1
    return vocab_to_ix
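
For reference, the vocab file this loader now reads is plain text with one token per line, the 0-based line number serving as the token index. A minimal sketch, assuming a hypothetical three-line vocab.txt:

<UNK>
<NUM>
patient

For this file, load_vocab() returns {'<UNK>': 0, '<NUM>': 1, 'patient': 2}.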


@@ -95,14 +98,13 @@ def preload_data(train_notes, train_labels, codes_to_targets, max_note_len=2000,
            x.append(xx)
            mask.append(m)
            row += 1
-            # print(row)

    x = np.vstack(x).astype(int)
    y = np.asarray(y)
    mask = np.vstack(mask).astype(np.float32)

    if save_path is not None:
-        log(f'Saving to {save_path}...')
+        log(f'Saving training data cache to {save_path}...')
        np.savez(save_path, x, y, mask)

    return x, y, mask
2 changes: 1 addition & 1 deletion modules/icd_modules.py
@@ -60,7 +60,7 @@ def __init__(self, word_emb, code_idx_matrix, code_idx_mask, adj_matrix, num_nei

        graph_encoder = get_graph_encoder(graph_encoder)
        if graph_encoder is not None:
-            log(f'Using {graph_encoder} for encoding ICD hierarchy...')
+            log(f'Using {graph_encoder.__name__} for encoding ICD hierarchy...')
            self.graph_label_encoder = graph_encoder(self.embed_size, label_hidden_size, self.n_nodes)
            self.feat_size += label_hidden_size
        else:
177 changes: 177 additions & 0 deletions preprocess.py
@@ -0,0 +1,177 @@
# Copyright 2020 Petuum, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import math
import os
import pickle
import shutil

import numpy as np
import pandas
import tqdm

from constant import PROCESSED_DIR, VOCAB_DICT_PATH
from utils.helper import log
from utils.tokenizer import Tokenizer

parser = argparse.ArgumentParser(description='Extract and preprocess MIMIC-III patient notes')
parser.add_argument('--mimic_dir', default=None, type=str, required=True,
                    help='directory to MIMIC-III dataset, including NOTEEVENTS.csv and DIAGNOSES_ICD.csv')


def make_folder(folder):
    if not os.path.exists(folder):
        os.makedirs(folder)


def remove_folder(folder):
    if os.path.exists(folder):
        shutil.rmtree(folder)


def is_discharge_summary(note_category):
    return 'discharge summary' in note_category.lower().strip()


def make_patient_dict(mimic_dir):
    read_file = f'{mimic_dir}/NOTEEVENTS.csv'
    log(f'Reading {read_file} ...')
    df_notes = pandas.read_csv(read_file, low_memory=False, dtype=str)

    read_file = f'{mimic_dir}/DIAGNOSES_ICD.csv'
    log(f'Reading {read_file} ...')
    df_icds = pandas.read_csv(read_file, low_memory=False, dtype=str)

    all_notes = df_notes['TEXT']
    all_note_types = df_notes['CATEGORY']
    all_note_descriptions = df_notes['DESCRIPTION']

    subject_ids_notes = df_notes['SUBJECT_ID']
    hadm_ids_notes = df_notes['HADM_ID']

    subject_ids_icd = df_icds['SUBJECT_ID']
    hadm_ids_icd = df_icds['HADM_ID']
    seq_nums_icd = df_icds['SEQ_NUM']
    icd9_codes = df_icds['ICD9_CODE']
    patient_dict = {(subject_id, hadm_id): [{}, {}] for subject_id, hadm_id in zip(subject_ids_notes, hadm_ids_notes)}

    # starting with the ICD code labels, collect only those (subject_id, hadm_id)
    # pairs that have at least one non-NaN ICD label
    for (subject_id, hadm_id, seq_num, icd9_code) in zip(subject_ids_icd, hadm_ids_icd, seq_nums_icd, icd9_codes):
        # some (subject_id, hadm_id) pairs appear in the ICD code data but not in NOTEEVENTS;
        # math.isnan raises TypeError for a string seq_num, in which case the code is kept
        try:
            if not math.isnan(seq_num):
                patient_dict[(subject_id, hadm_id)][1][seq_num] = icd9_code
        except TypeError:
            try:
                patient_dict[(subject_id, hadm_id)][1][seq_num] = icd9_code
            except KeyError:  # not in the notes data
                pass

    for (subject_id, hadm_id, note, note_type, note_description) in zip(subject_ids_notes, hadm_ids_notes, all_notes,
                                                                        all_note_types, all_note_descriptions):
        if is_discharge_summary(note_type):
            if (note_type, note_description) in patient_dict[(subject_id, hadm_id)][0]:
                patient_dict[(subject_id, hadm_id)][0][(note_type, note_description)].append(note)
            else:
                patient_dict[(subject_id, hadm_id)][0][(note_type, note_description)] = [note]

    to_remove = []
    for (subject_id, hadm_id) in patient_dict:
        if len(patient_dict[(subject_id, hadm_id)][0]) == 0 or len(patient_dict[(subject_id, hadm_id)][1]) == 0:
            to_remove.append((subject_id, hadm_id))
    for key in to_remove:
        patient_dict.pop(key)

    log(f'Total number of (subject_id, hadm_id) pairs with a discharge summary and at least 1 code: {len(patient_dict)}')
    return patient_dict
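
For reference, a single surviving patient_dict entry has this shape (the IDs, codes, and note text below are hypothetical; all values are strings because both CSVs are read with dtype=str):

patient_dict[('109', '172335')] == [
    {('Discharge summary', 'Report'): ['Admission Date: ...']},  # notes keyed by (CATEGORY, DESCRIPTION)
    {'1': '40301', '2': '486'},  # ICD-9 codes keyed by SEQ_NUM
]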


def concat_and_write(list_of_notes, concatenated_file):
    concatenated_text = ''.join(list_of_notes)
    with open(concatenated_file, 'w') as f:
        f.write(concatenated_text)


def make_text_files(mimic_dir, save_dir):
    patient_dict = make_patient_dict(mimic_dir)

    text_save_dir = f'{save_dir}/text_files/'
    make_folder(text_save_dir)
    label_save_dir = f'{save_dir}/label_files/'
    make_folder(label_save_dir)

    total_txt_count = 0
    for (subject_id, hadm_id) in tqdm.tqdm(patient_dict, desc='Extracting text files'):
        icd9_dict = patient_dict[(subject_id, hadm_id)][1]

        all_descriptions = []
        for category, description in patient_dict[(subject_id, hadm_id)][0].keys():
            notes = patient_dict[(subject_id, hadm_id)][0][(category, description)]
            all_descriptions.extend(notes)

        # write the concatenated discharge notes
        text_save_path = f'{text_save_dir}/{subject_id}_{hadm_id}_notes.txt'
        concat_and_write(all_descriptions, text_save_path)
        # write the ICD labels
        label_save_path = f'{label_save_dir}/{subject_id}_{hadm_id}_labels.txt'
        with open(label_save_path, 'w') as f:
            for key in icd9_dict:
                f.write('{}, {}\n'.format(key, icd9_dict[key]))
        total_txt_count += 1

    log(f'Wrote {total_txt_count} text files to {save_dir}')


def preprocess_raw_text(save_dir):
    text_save_dir = os.path.join(save_dir, 'text_files')
    numpy_vectors_save_dir = os.path.join(save_dir, 'numpy_vectors')
    remove_folder(numpy_vectors_save_dir)
    make_folder(numpy_vectors_save_dir)
    hadms = []
    for filename in os.listdir(text_save_dir):
        if ".txt" in filename:
            hadm = filename.replace(".txt", "")
            hadms.append(hadm)
    log(f"Total number of text files in set: {len(hadms)}")

    log(f'Loading vocab dict from {VOCAB_DICT_PATH}')
    with open(VOCAB_DICT_PATH, 'rb') as f:
        vocab = pickle.load(f)
    tokenizer = Tokenizer(vocab)

    for hadm in tqdm.tqdm(hadms, desc='Generating processed texts'):
        with open(os.path.join(text_save_dir, str(hadm) + ".txt"), "r") as f:
            text = f.read()
        words = tokenizer.process(text)
        vector = []
        for word in words:
            if word in vocab:
                vector.append(vocab[word])
            elif tokenizer.only_numerals(word) and (len(vector) == 0 or vector[-1] != vocab["<NUM>"]):
                vector.append(vocab["<NUM>"])

        mat = np.array(vector)
        # saving word indices to file
        write_file = os.path.join(numpy_vectors_save_dir, f"{hadm}.npy")
        np.save(write_file, mat)


if __name__ == '__main__':
    args = parser.parse_args()
    make_text_files(args.mimic_dir, PROCESSED_DIR)
    preprocess_raw_text(PROCESSED_DIR)
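
A usage sketch, with a placeholder path for the MIMIC-III CSVs:

python preprocess.py --mimic_dir /path/to/mimic-iii

This writes per-admission note and label files under PROCESSED_DIR/text_files and PROCESSED_DIR/label_files, then saves each admission's tokenized word-index vector as an .npy file under PROCESSED_DIR/numpy_vectors.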
5 changes: 3 additions & 2 deletions requirements.txt
@@ -1,7 +1,8 @@
-torch==1.3.1
-numpy==1.16.4
+torch==1.4.0
+numpy==1.18.2
nltk==3.4.4
scikit-learn==0.21.2
matplotlib==3.1.1
gensim==3.7.3
tqdm==4.43.0
+pandas==0.24.2
5 changes: 3 additions & 2 deletions train_base.py
@@ -211,7 +211,7 @@ def eval_trained(eval_batch_size=16, max_note_len=2000, loss='bce', gpu='cuda:1'
                          graph_encoder=graph_encoder, eval_code_size=eval_code_size,
                          target_count=target_count if class_margin else None, C=C)

-    pretrained_model_path = f"{MODEL_DIR}/final_{model.name}"
+    pretrained_model_path = f"{MODEL_DIR}/{model.name}"
    pretrained_dict = load_first_stage_model(pretrained_model_path, device)

    model_dict = model.state_dict()
@@ -227,7 +227,6 @@ def eval_trained(eval_batch_size=16, max_note_len=2000, loss='bce', gpu='cuda:1'
    log('Preloading data in memory...')
    dev_x, dev_y, dev_masks = preload_data(dev_notes, dev_labels, codes_to_targets, max_note_len)
    test_x, test_y, test_masks = preload_data(test_notes, test_labels, codes_to_targets, max_note_len)
-    log('Evaluating...')

    def eval_wrapper(x, y, masks):
        y_true = []
@@ -250,9 +249,11 @@ def eval_wrapper(x, y, masks):
        y_true = np.vstack(y_true).astype(int)
        return y_true, y_score

+    log('Evaluating on dev set...')
    dev_true, dev_score = eval_wrapper(dev_x, dev_y, dev_masks)
    log_eval_metrics(0, dev_score, dev_true, dev_freq_indices, dev_few_shot_indices, dev_zero_shot_indices)

+    log('Evaluating on test set...')
    test_true, test_score = eval_wrapper(test_x, test_y, test_masks)
    log_eval_metrics(0, test_score, test_true, test_freq_indices, test_few_shot_indices, test_zero_shot_indices)

69 changes: 69 additions & 0 deletions utils/tokenizer.py
@@ -0,0 +1,69 @@
# Copyright 2020 Petuum, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import re


class Tokenizer(object):
    def __init__(self, vocab: dict):
        super(Tokenizer, self).__init__()
        self.vocab = vocab
        self.special_rep = 'NOUN'
        self.unk = "<UNK>"
        self.num = "<NUM>"
        # matches de-identified MIMIC-III placeholders such as [**Name**]
        self.pat = re.compile(r'(\[\*\*[^\[]*\*\*\])')

    def remove_special_token(self, sent):
        return self.pat.sub(self.special_rep, sent)

    @staticmethod
    def tokenize(sent):
        words = [s for s in re.split(r"\W+", sent) if s and not s.isspace()]
        return words

    def replace_unknowns_nums(self, words):
        tokens = []
        for word in words:
            # drop the placeholder substituted for de-identified spans
            if self.special_rep.lower() == word:
                continue

            if word in self.vocab:
                tokens.append(word)
            else:
                token = self.distinguish_unk_num(word)
                # collapse runs of consecutive <UNK> or <NUM> tokens
                if len(tokens) == 0 or tokens[-1] != token:
                    tokens.append(token)
        return tokens

    def distinguish_unk_num(self, word):
        if self.only_numerals(word):
            return self.num
        else:
            return self.unk

    @staticmethod
    def only_numerals(string):
        try:
            int(string)
            return True
        except ValueError:
            return False

    def process(self, input_text: str):
        input_text = self.remove_special_token(input_text)
        words_tokenized = self.tokenize(input_text)
        words = [word.lower().strip() for word in words_tokenized]
        words = self.replace_unknowns_nums(words)
        return words
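
A minimal usage sketch of the tokenizer; the toy vocab below is hypothetical (the pipeline loads the real dict from VOCAB_DICT_PATH):

from utils.tokenizer import Tokenizer

vocab = {'patient': 0, 'admitted': 1, '<UNK>': 2, '<NUM>': 3}  # toy vocab, for illustration only
tokenizer = Tokenizer(vocab)
print(tokenizer.process('Patient [**Name**] admitted 3 14.'))
# -> ['patient', 'admitted', '<NUM>']: placeholder dropped, consecutive numerals collapsed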
