-
Notifications
You must be signed in to change notification settings - Fork 64
/
test.py
50 lines (40 loc) · 1.94 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# -*- coding: utf-8 -*-
# file: test.py
# time: 29/01/2022
# author: yangheng <yangheng@m.scnu.edu.cn>
# github: https://github.com/yangheng95
# Copyright (C) 2021. All Rights Reserved.
import os
import findfile
from findfile import rm_files
from pyabsa import check_package_version, generate_inference_set_for_apc, convert_apc_set_to_atepc_set
# Check that the installed pyabsa meets the minimum version this script was
# written against (v2.0.0); presumably warns or raises otherwise — see pyabsa.
check_package_version(min_version='v2.0.0')

# Disabled one-off helper, kept for reference. It converts a CSV with 'text' and
# 'aspect' columns (e.g. the 129.Kaggle test split) into a '.dat' file with three
# lines per usable row: the sentence with the aspect replaced by '$T$', the
# aspect itself, and a sentiment label predicted by the 'fast' pretrained
# classifier (the CSV's own 'label' column is the commented-out alternative).
# Rows whose aspect/text is blank, or whose aspect does not occur in the text,
# are skipped. Indentation is restored so it can be re-enabled by uncommenting.
# NOTE(review): label_map assumes the classifier reports 0/1/2 (int or str)
# sentiment codes — verify against the installed pyabsa before re-enabling.
# from pyabsa import APCCheckpointManager
# classifier = APCCheckpointManager.get_sentiment_classifier('fast')
# label_map = {0: 'Negative', 1: 'Neutral', 2: 'Positive', '0': 'Negative', '1': 'Neutral', '2': 'Positive'}
# def read_csv_by_pandas(path):
#     import pandas as pd
#     df = pd.read_csv(path)
#     with open(path+'.dat', 'w', encoding='utf-8') as f:
#         for index, row in df.iterrows():
#             if row['aspect'].strip() and row['text'].strip() and row['aspect'] in row['text']:
#                 text = row['text'].replace('\r', '').replace('\n', '').strip()
#                 aspect = row['aspect'].replace('\r', '').replace('\n', '').strip()
#                 res = classifier.infer(text.replace(aspect, '[ASP]{}[ASP]'.format(aspect)))
#                 f.write(text.replace(aspect, '$T$')+'\n')
#                 f.write(aspect+'\n')
#                 f.write(label_map[res['sentiment'][0]]+'\n')
#                 # f.write(label_map[row['label']]+'\n')
#     return df
# read_csv_by_pandas('datasets/apc_datasets/129.Kaggle/test.csv')
# Start from a clean slate: delete ATEPC augmented files already present under
# the working directory, both with the freshly generated '.ignore.atepc' suffix
# and with the renamed '.atepc.ignore' suffix produced further below.
findfile.rm_files(os.getcwd(), '.ignore.atepc')
findfile.rm_files(os.getcwd(), '.atepc.ignore')

# Batch conversion for all ABSA datasets under 'apc_datasets': generate the APC
# inference sets, then derive ATEPC-format datasets from the APC ones.
generate_inference_set_for_apc('apc_datasets')
convert_apc_set_to_atepc_set('apc_datasets')

# By default, make the ATEPC augmented datasets invisible by renaming
# '*.ignore.atepc' to '*.atepc.ignore'. Only the LAST occurrence of the marker
# is rewritten: a plain str.replace() would also rewrite any directory name
# containing the same text, turning the target into a non-existent path.
for f in findfile.find_cwd_files('.ignore.atepc'):
    head, marker, tail = f.rpartition('.ignore.atepc')
    if not marker:
        # Defensive: the path does not literally contain the marker, so there
        # is nothing to rename (the old replace()-based rename was a no-op here).
        continue
    os.rename(f, head + '.atepc.ignore' + tail)

# Remove the train and valid inference sets, as they are useless.
# NOTE(review): assumes rm_files defaults to the current directory and only
# removes files whose path contains every key in the list — confirm against
# the installed findfile version.
rm_files(key=['train', 'inference'])
rm_files(key=['valid', 'inference'])