predict.py (forked from jiegzhan/multi-class-text-classification-cnn-rnn)
import os
import sys
import json
import shutil
import pickle
import logging

import numpy as np
import pandas as pd
import tensorflow as tf

import data_helper
from text_cnn_rnn import TextCNNRNN

logging.getLogger().setLevel(logging.INFO)


def load_trained_params(trained_dir):
    # Load the hyperparameters, vocabulary, label list and embedding matrix
    # that training saved into trained_dir.
    params = json.loads(open(trained_dir + 'trained_parameters.json').read())
    words_index = json.loads(open(trained_dir + 'words_index.json').read())
    labels = json.loads(open(trained_dir + 'labels.json').read())

    with open(trained_dir + 'embeddings.pickle', 'rb') as input_file:
        fetched_embedding = pickle.load(input_file)
    embedding_mat = np.array(fetched_embedding, dtype=np.float32)
    return params, words_index, labels, embedding_mat
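

# Illustrative helper, not part of the original script: load_trained_params assumes
# trained_dir contains the four artifacts written out during training. A check along
# these lines fails fast with a clearer message than a bare IOError would give.
def check_trained_dir(trained_dir):
    required = ['trained_parameters.json', 'words_index.json',
                'labels.json', 'embeddings.pickle']
    missing = [f for f in required if not os.path.exists(trained_dir + f)]
    if missing:
        raise IOError('missing trained artifacts in {}: {}'.format(trained_dir, missing))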


def load_test_data(test_file, labels):
    # The test file is pipe-separated and must contain a 'Descript' column; a
    # 'Category' column is optional and, when present, enables accuracy reporting.
    df = pd.read_csv(test_file, sep='|')
    select = ['Descript']

    df = df.dropna(axis=0, how='any', subset=select)
    test_examples = df[select[0]].apply(lambda x: data_helper.clean_str(x).split(' ')).tolist()

    # Build a one-hot row vector for each label.
    num_labels = len(labels)
    one_hot = np.zeros((num_labels, num_labels), int)
    np.fill_diagonal(one_hot, 1)
    label_dict = dict(zip(labels, one_hot))

    y_ = None
    if 'Category' in df.columns:
        select.append('Category')
        y_ = df[select[1]].apply(lambda x: label_dict[x]).tolist()

    not_select = list(set(df.columns) - set(select))
    df = df.drop(not_select, axis=1)
    return test_examples, y_, df
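

# For reference, a minimal test file accepted by load_test_data; the column names
# come from the code above, while the sample rows are made up for illustration:
#
#   Category|Descript
#   WARRANTS|WARRANT ARREST
#   LARCENY/THEFT|GRAND THEFT FROM LOCKED AUTO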


def map_word_to_index(examples, words_index):
    # Replace each word with its vocabulary index; out-of-vocabulary words map to 0.
    x_ = []
    for example in examples:
        temp = []
        for word in example:
            if word in words_index:
                temp.append(words_index[word])
            else:
                temp.append(0)
        x_.append(temp)
    return x_
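

# Example with a hypothetical vocabulary: given words_index = {'theft': 7, 'auto': 12},
# map_word_to_index([['theft', 'from', 'auto']], words_index) returns [[7, 0, 12]],
# since 'from' is out of vocabulary and falls back to index 0.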


def predict_unseen_data():
    # Usage: python predict.py <trained_dir> <test_file>
    trained_dir = sys.argv[1]
    if not trained_dir.endswith('/'):
        trained_dir += '/'
    test_file = sys.argv[2]

    params, words_index, labels, embedding_mat = load_trained_params(trained_dir)
    x_, y_, df = load_test_data(test_file, labels)
    x_ = data_helper.pad_sentences(x_, forced_sequence_length=params['sequence_length'])
    x_ = map_word_to_index(x_, words_index)

    x_test, y_test = np.asarray(x_), None
    if y_ is not None:
        y_test = np.asarray(y_)

    # Name the output directory after the timestamp suffix of the trained model.
    timestamp = trained_dir.split('/')[-2].split('_')[-1]
    predicted_dir = './predicted_results_' + timestamp + '/'
    if os.path.exists(predicted_dir):
        shutil.rmtree(predicted_dir)
    os.makedirs(predicted_dir)

    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            cnn_rnn = TextCNNRNN(
                embedding_mat=embedding_mat,
                non_static=params['non_static'],
                hidden_unit=params['hidden_unit'],
                sequence_length=len(x_test[0]),
                max_pool_size=params['max_pool_size'],
                # list() keeps this working on Python 3, where map() returns an iterator.
                filter_sizes=list(map(int, params['filter_sizes'].split(","))),
                num_filters=params['num_filters'],
                num_classes=len(labels),
                embedding_size=params['embedding_dim'],
                l2_reg_lambda=params['l2_reg_lambda'])
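
            # For reference, trained_parameters.json supplies the hyperparameters read
            # above. The keys are the ones used in this file; the values below are
            # illustrative placeholders, not the original training configuration:
            #
            #   {"batch_size": 128, "embedding_dim": 300, "filter_sizes": "3,4,5",
            #    "hidden_unit": 300, "l2_reg_lambda": 0.0, "max_pool_size": 4,
            #    "non_static": false, "num_filters": 32, "sequence_length": 16}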

            def real_len(batches):
                # Number of pooled steps per sequence: argmin finds the first
                # padding zero, which marks the sequence's effective length.
                return [np.ceil(np.argmin(batch + [0]) * 1.0 / params['max_pool_size']) for batch in batches]

            def predict_step(x_batch):
                feed_dict = {
                    cnn_rnn.input_x: x_batch,
                    cnn_rnn.dropout_keep_prob: 1.0,  # no dropout at inference time
                    cnn_rnn.batch_size: len(x_batch),
                    cnn_rnn.pad: np.zeros([len(x_batch), 1, params['embedding_dim'], 1]),
                    cnn_rnn.real_len: real_len(x_batch),
                }
                predictions = sess.run([cnn_rnn.predictions], feed_dict)
                return predictions

            # Restore the trained weights from the checkpoint's meta graph.
            checkpoint_file = trained_dir + 'best_model.ckpt'
            saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file[:-5]))
            saver.restore(sess, checkpoint_file)
            logging.critical('{} has been loaded'.format(checkpoint_file))

            batches = data_helper.batch_iter(list(x_test), params['batch_size'], 1, shuffle=False)

            predictions, predict_labels = [], []
            for x_batch in batches:
                batch_predictions = predict_step(x_batch)[0]
                for batch_prediction in batch_predictions:
                    predictions.append(batch_prediction)
                    predict_labels.append(labels[batch_prediction])

            # Write the predictions alongside the retained input columns.
            df['PREDICTED'] = predict_labels
            columns = sorted(df.columns, reverse=True)
            df.to_csv(predicted_dir + 'predictions_all.csv', index=False, columns=columns, sep='|')

            # Report accuracy only when the test file provided true labels.
            if y_test is not None:
                y_test = np.array(np.argmax(y_test, axis=1))
                accuracy = sum(np.array(predictions) == y_test) / float(len(y_test))
                logging.critical('The prediction accuracy is: {}'.format(accuracy))

            logging.critical('Prediction is complete, all files have been saved: {}'.format(predicted_dir))


if __name__ == '__main__':
    predict_unseen_data()
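

# Example invocation (paths are illustrative; the trained directory is the
# timestamped output of training, and the test file is pipe-separated):
#
#   python predict.py ./trained_results_1478563595/ ./data/small_samples.csv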