crosslingual_topic_model.py
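"""Entry point for training and evaluating a cross-lingual topic model.

Usage sketch (assumed invocation; <model> and <dataset> are placeholders that
must name YAML files under ./configs/model/ and ./configs/dataset/):

    python crosslingual_topic_model.py --model <model> --dataset <dataset> --num_topic 50
"""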
import argparse
import os

import scipy.io
import yaml

from runners.Runner import Runner
from utils.data import file_utils
from utils.data.TextData import DatasetHandler
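

# Command-line flags; model- and dataset-specific settings are merged in
# from the YAML configs inside main().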
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model')
    parser.add_argument('--dataset')
    parser.add_argument('--weight_MI', type=float)
    parser.add_argument('--num_topic', type=int, default=50)
    args = parser.parse_args()
    return args
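

# Write the top words of every topic to a text file and return them as
# printable strings.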
def export_beta(beta, vocab, output_prefix, lang):
    num_top_word = 15
    topic_str_list = file_utils.print_topic_words(beta, vocab, num_top_word=num_top_word)
    file_utils.save_text(topic_str_list, path=f'{output_prefix}_T{num_top_word}_{lang}')
    return topic_str_list


def main():
    args = parse_args()
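    # Merge model- and dataset-specific settings from the YAML configs into args.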
    args = file_utils.update_args(args, f'./configs/model/{args.model}.yaml')
    args = file_utils.update_args(args, f'./configs/dataset/{args.dataset}.yaml')
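
    # All saved artifacts share this prefix, e.g. output/<dataset>/<model>_K50.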
    output_prefix = f'output/{args.dataset}/{args.model}_K{args.num_topic}'
    file_utils.make_dir(os.path.dirname(output_prefix))

    print('\n' + yaml.dump(vars(args), default_flow_style=False))
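
    # Load the bilingual corpus, vocabularies, pretrained embeddings, and the
    # cross-lingual dictionary.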
    dataset_handler = DatasetHandler(args.dataset, args.batch_size, args.lang1, args.lang2, args.dict_path)
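    # Expose dataset statistics and pretrained word embeddings to the model
    # through args.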
    args.vocab_size_en = len(dataset_handler.vocab_en)
    args.vocab_size_cn = len(dataset_handler.vocab_cn)
    args.pretrain_word_embeddings_en = dataset_handler.pretrain_word_embeddings_en
    args.pretrain_word_embeddings_cn = dataset_handler.pretrain_word_embeddings_cn
    args.vocab_en = dataset_handler.vocab_en
    args.vocab_cn = dataset_handler.vocab_cn
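
    # Pass the English-side translation matrix to the runner as an extra parameter.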
    params_list = [dataset_handler.trans_matrix_en]
    runner = Runner(args, params_list)
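
    # Train the model; returns the topic-word distributions (beta) for both
    # languages.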
    beta_en, beta_cn = runner.train(dataset_handler.train_loader)
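
    # Export the top words of each topic for each language.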
    topic_str_list_en = export_beta(beta_en, dataset_handler.vocab_en, output_prefix, lang='en')
    topic_str_list_cn = export_beta(beta_cn, dataset_handler.vocab_cn, output_prefix, lang='cn')
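
    # Print the topic pairs side by side, each English topic followed by its
    # index-aligned counterpart.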
    for topic_str_en, topic_str_cn in zip(topic_str_list_en, topic_str_list_cn):
        print(topic_str_en)
        print(topic_str_cn)
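
    # Infer document-topic distributions (theta) for the training and test sets.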
    train_theta_en, train_theta_cn = runner.test(dataset_handler.train_loader.dataset)
    test_theta_en, test_theta_cn = runner.test(dataset_handler.test_loader.dataset)
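
    # Gather all results and save them as a MATLAB .mat file.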
    rst_dict = {
        'beta_en': beta_en,
        'beta_cn': beta_cn,
        'train_theta_en': train_theta_en,
        'train_theta_cn': train_theta_cn,
        'test_theta_en': test_theta_en,
        'test_theta_cn': test_theta_cn,
    }
    scipy.io.savemat(f'{output_prefix}_rst.mat', rst_dict)


if __name__ == '__main__':
    main()