-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.py
191 lines (160 loc) · 6.29 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
from scipy.sparse import load_npz
import numpy as np
import csv
import os
def _load_csv(path):
    """Load a ``question_id,user_id[,is_correct]`` CSV file into column lists.

    Rows whose first two columns cannot both be parsed as integers (the header
    row, blank lines, malformed rows) are skipped entirely, so the
    ``question_id`` and ``user_id`` lists always stay the same length. When the
    third column is missing or not an integer, the row's ids are still
    recorded but nothing is appended to ``is_correct``.

    :param path: str, path to the CSV file.
    :return: dict with keys "user_id", "question_id" and "is_correct", each a
        list of ints.
    :raises Exception: if ``path`` does not exist.
    """
    if not os.path.exists(path):
        raise Exception("The specified path {} does not exist.".format(path))
    data = {
        "user_id": [],
        "question_id": [],
        "is_correct": [],
    }
    # newline="" is the mode the csv module documents for files it reads.
    with open(path, "r", newline="") as csv_file:
        for row in csv.reader(csv_file):
            # Parse both ids before appending either one. Otherwise a row such
            # as "2,abc,0" would leave an orphan question_id behind.
            try:
                question_id = int(row[0])
                user_id = int(row[1])
            except (ValueError, IndexError):
                # Header row, blank line, or malformed ids: skip the row.
                continue
            data["question_id"].append(question_id)
            data["user_id"].append(user_id)
            try:
                data["is_correct"].append(int(row[2]))
            except (ValueError, IndexError):
                # is_correct might not be available (missing or empty column).
                pass
    return data
def load_train_sparse(root_dir="/data"):
    """Load the training data as a sparse matrix from ``train_sparse.npz``.

    :param root_dir: str, directory that contains ``train_sparse.npz``.
    :return: 2D sparse matrix, as returned by ``scipy.sparse.load_npz``.
    :raises Exception: if ``train_sparse.npz`` does not exist; the message
        reports the absolute path that was checked.
    """
    npz_path = os.path.join(root_dir, "train_sparse.npz")
    if os.path.exists(npz_path):
        return load_npz(npz_path)
    raise Exception(
        "The specified path {} does not exist.".format(os.path.abspath(npz_path))
    )
def load_train_csv(root_dir="/data"):
    """Load the training split stored in ``train_data.csv``.

    :param root_dir: str, directory that contains ``train_data.csv``.
    :return: A dictionary {user_id: list, question_id: list, is_correct: list}
        WHERE
        user_id: a list of user id.
        question_id: a list of question id.
        is_correct: a list of integer labels (expected to be binary) giving
        the correctness of each (user_id, question_id) pair.
    :raises Exception: if the CSV file does not exist.
    """
    return _load_csv(os.path.join(root_dir, "train_data.csv"))
def load_valid_csv(root_dir="/data"):
    """Load the validation split stored in ``valid_data.csv``.

    :param root_dir: str, directory that contains ``valid_data.csv``.
    :return: A dictionary {user_id: list, question_id: list, is_correct: list}
        WHERE
        user_id: a list of user id.
        question_id: a list of question id.
        is_correct: a list of integer labels (expected to be binary) giving
        the correctness of each (user_id, question_id) pair.
    :raises Exception: if the CSV file does not exist.
    """
    return _load_csv(os.path.join(root_dir, "valid_data.csv"))
def load_public_test_csv(root_dir="/data"):
    """Load the public test split stored in ``test_data.csv``.

    :param root_dir: str, directory that contains ``test_data.csv``.
    :return: A dictionary {user_id: list, question_id: list, is_correct: list}
        WHERE
        user_id: a list of user id.
        question_id: a list of question id.
        is_correct: a list of integer labels (expected to be binary) giving
        the correctness of each (user_id, question_id) pair.
    :raises Exception: if the CSV file does not exist.
    """
    return _load_csv(os.path.join(root_dir, "test_data.csv"))
def load_private_test_csv(root_dir="/data"):
    """Load the private test split stored in ``private_test_data.csv``.

    :param root_dir: str, directory that contains ``private_test_data.csv``.
    :return: A dictionary {user_id: list, question_id: list, is_correct: list}
        WHERE
        user_id: a list of user id.
        question_id: a list of question id.
        is_correct: only filled for rows whose third column parses as an
        integer; expected to be an empty list for the private split.
    :raises Exception: if the CSV file does not exist.
    """
    return _load_csv(os.path.join(root_dir, "private_test_data.csv"))
def save_private_test_csv(data, file_name="private_test_result.csv"):
    """ Save the private test data as a csv file.
    This should be your submission file to Kaggle.

    The output has a header row ``id,is_correct`` followed by one row per
    entry of ``data["user_id"]``. Ids are numbered from 1 and the label is
    written as ``0`` or ``1``.

    :param data: A dictionary {user_id: list, question_id: list, is_correct: list}
    WHERE
    user_id: a list of user id.
    question_id: a list of question id.
    is_correct: a list of binary value indicating the correctness of
    (user_id, question_id) pair.
    :param file_name: str
    :return: None
    :raises Exception: if ``data`` is not a dict, or if any label is not an
        integral 0 or 1. Validation happens before the file is opened, so
        invalid data leaves no partially written file behind.
    """
    if not isinstance(data, dict):
        raise Exception("Data must be a dictionary.")
    labels = []
    for i in range(len(data["user_id"])):
        value = data["is_correct"][i]
        label = int(value)
        # int() truncates, so also require that the value was already
        # integral. Otherwise a probability such as 0.7 would silently
        # become 0.
        if label not in (0, 1) or float(value) != label:
            raise Exception("Your data['is_correct'] is not in a valid format.")
        labels.append(label)
    # newline="" stops the csv module from emitting blank lines on Windows.
    with open(file_name, "w", newline="") as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(["id", "is_correct"])
        writer.writerows(
            [str(row_id), str(label)]
            for row_id, label in enumerate(labels, start=1)
        )
def evaluate(data, predictions, threshold=0.5):
    """ Return the accuracy of the predictions given the data.

    A prediction counts as 1 when it is ``>= threshold`` and as 0 otherwise.
    The result is the fraction of entries whose thresholded prediction equals
    ``data["is_correct"]``.

    :param data: A dictionary {user_id: list, question_id: list, is_correct: list}
    :param predictions: array-like of scores (list, tuple or numpy array).
        Multi-dimensional input such as an (n, 1) column is flattened.
    :param threshold: float
    :return: float
    :raises Exception: if the number of predictions differs from the number
        of labels.
    """
    targets = np.asarray(data["is_correct"]).ravel()
    # ravel() keeps an (n, 1) column of scores from broadcasting against the
    # (n,) labels into an (n, n) comparison, which would give a bogus
    # accuracy.
    scores = np.asarray(predictions, dtype=np.float64).ravel()
    if targets.size != scores.size:
        raise Exception("Mismatch of dimensions between data and prediction.")
    return np.sum((scores >= threshold) == targets) / float(targets.size)
def sparse_matrix_evaluate(data, matrix, threshold=0.5):
    """Return the accuracy of thresholded ``matrix`` entries on ``data``.

    For each position ``k`` the score ``matrix[user_id[k], question_id[k]]``
    is accurate when it is ``>= threshold`` and ``is_correct[k]`` is truthy,
    or when it is ``< threshold`` and ``is_correct[k]`` is falsy. A score for
    which neither comparison holds (e.g. NaN) is never counted as accurate.

    :param data: A dictionary {user_id: list, question_id: list, is_correct: list}
    :param matrix: 2D matrix indexable as ``matrix[row, col]``
    :param threshold: float
    :return: float, accurate count divided by ``len(data["is_correct"])``.
        Raises ZeroDivisionError when ``is_correct`` is empty.
    """
    labels = data["is_correct"]
    hits = 0
    for k, label in enumerate(labels):
        score = matrix[data["user_id"][k], data["question_id"][k]]
        if label:
            if score >= threshold:
                hits += 1
        elif score < threshold:
            hits += 1
    return hits / float(len(labels))
def sparse_matrix_predictions(data, matrix, threshold=0.5):
    """Threshold ``matrix`` entries into 0/1 predictions for each data row.

    Suitable for building a Kaggle submission. Entry ``k`` of the result is
    ``1.`` when ``matrix[user_id[k], question_id[k]] >= threshold`` and
    ``0.`` otherwise (including NaN scores).

    :param data: A dictionary {user_id: list, question_id: list, is_correct: list}
    :param matrix: 2D matrix indexable as ``matrix[row, col]``
    :param threshold: float
    :return: list of floats, one per entry of ``data["user_id"]``
    """
    return [
        1. if matrix[user, data["question_id"][k]] >= threshold else 0.
        for k, user in enumerate(data["user_id"])
    ]