-
Notifications
You must be signed in to change notification settings - Fork 12
/
run.py
47 lines (38 loc) · 1.48 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import model
import dataset
import numpy as np
def word_to_hash(chars, word, max_length=10):
    """Encode ``word`` as a fixed-width row of character indices.

    The word is lower-cased, each character is replaced by its position in
    ``chars``, and the result is right-padded with ``-1`` (or truncated) so
    that it has exactly ``max_length`` entries.

    Args:
        chars: Iterable of the known characters; a character's (first)
            position in this sequence is its code.
        word: The word to encode.
        max_length: Number of columns in the returned row. Defaults to 10,
            the width this function has always produced.

    Returns:
        An integer numpy array of shape ``(1, max_length)``.

    Raises:
        ValueError: If ``word`` contains a character that is not in
            ``chars``.
    """
    # Build the lookup once so each character costs one dict access instead
    # of a linear ``list.index`` scan. ``setdefault`` keeps the first
    # position of a repeated character, matching ``list.index``.
    index_of = {}
    for position, char in enumerate(chars):
        index_of.setdefault(char, position)

    lowered = word.lower()
    unknown = sorted({char for char in lowered if char not in index_of})
    if unknown:
        raise ValueError(
            "word {!r} contains characters outside the known alphabet: {}".format(
                word, ", ".join(repr(char) for char in unknown)
            )
        )

    # Words longer than max_length are cut explicitly; shorter ones are
    # padded with -1 up to the fixed width.
    codes = [index_of[char] for char in lowered[:max_length]]
    codes.extend([-1] * (max_length - len(codes)))
    return np.array(codes, dtype=int).reshape(1, max_length)
def get_predicted_language(probs):
    """Return the highest probability in ``probs`` and the language it maps to.

    Position ``i`` of ``probs`` corresponds to position ``i`` of the fixed
    language list (English, Spanish, Finnish, Dutch, Polish). On ties the
    earliest position wins.

    Args:
        probs: Sequence (list or 1-D array) of scores, one per language.

    Returns:
        A ``(max_value, language_name)`` tuple.

    Raises:
        ValueError: If ``probs`` is empty.
        IndexError: If the largest value sits at a position that has no
            corresponding language name.
    """
    languages = ("English", "Spanish", "Finnish", "Dutch", "Polish")

    # ``len`` rather than truthiness: numpy arrays do not support ``not arr``.
    if len(probs) == 0:
        raise ValueError("probs must contain at least one probability")

    best_index = 0
    best_value = -float("inf")
    for index, value in enumerate(probs):
        # Strict ">" keeps the first maximum on ties and skips NaN entries.
        if value > best_value:
            best_index = index
            best_value = value
    return (best_value, languages[best_index])
def main():
    """Train the classifier, print test accuracy, then classify typed words.

    Builds a ``model.LanguageClassificationModel`` and a
    ``dataset.LanguageClassificationDataset`` around it, trains the model on
    the dataset, and prints the mean agreement between predicted and correct
    labels on the ``'test'`` split. It then prompts for words in a loop and
    prints the predicted language with its confidence until the user enters
    ``q`` or standard input is closed.
    """
    language_classifier = model.LanguageClassificationModel()
    data = dataset.LanguageClassificationDataset(language_classifier)
    chars = data.chars

    language_classifier.train(data)

    # The per-class probabilities returned first are not needed here.
    _, test_predicted, test_correct = data._predict('test')
    test_accuracy = np.mean(test_predicted == test_correct)
    print("test set accuracy is: {:%}\n".format(test_accuracy))

    while True:
        try:
            word = input("Enter a word(press q to quit): ")
        except EOFError:
            # stdin was closed (Ctrl-D or exhausted pipe): leave like "q".
            print()
            break

        word = word.strip()
        if word == "q":
            break
        if not word:
            # A blank entry would encode to a row of pure padding, which is
            # not a word to classify; just prompt again.
            continue

        try:
            hashed = word_to_hash(chars, word)
        except ValueError as error:
            # Raised when the word has a character outside ``chars``;
            # report it and keep the session alive instead of crashing.
            print("cannot classify {!r}: {}\n".format(word, error))
            continue

        xs = data._encode(hashed, None, True)
        result = language_classifier.run(xs)
        probs = data._softmax(result.data)
        # Only one word was encoded, so the first row holds its scores.
        max_prob, pred_lang = get_predicted_language(probs[0])
        print("predicted language is: {}, with a confidence of {:%}\n".format(pred_lang, max_prob))
# Run the training and interactive session only when this file is executed
# as a script, not when it is imported as a module.
if __name__ == "__main__":
    main()