-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprepare.py
70 lines (53 loc) · 1.58 KB
/
prepare.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import math
import os
import random
from os import path

import emoji
import nltk
# Fetch the NLTK 'punkt' resource before any text is processed; the
# normalize() helper in this file tokenizes with nltk.word_tokenize.
# NOTE(review): presumably word_tokenize needs this resource at call time;
# confirm the resource name against the installed NLTK version.
print('Downloading PUNKT model...')
nltk.download('punkt')
# Progress message for the normalization and train/validation split stage.
print('Normalizing and splitting data...')
# Skin-tone modifier characters. normalize_emojis() never puts a space in
# front of these, so they stay attached to the character they follow.
color_emojis = ['🏻', '🏼', '🏽', '🏾', '🏿']
def normalize_emojis(text):
    """Put a single space in front of every emoji character in *text*.

    A character counts as an emoji when it is a member of
    ``emoji.UNICODE_EMOJI``. Characters listed in ``color_emojis`` are
    copied through unchanged even if they are members, and so is every
    non-emoji character.

    NOTE(review): the membership test works per character, so it assumes
    ``emoji.UNICODE_EMOJI`` is keyed by emoji characters; confirm this
    against the installed ``emoji`` package version.
    """
    pieces = []
    for ch in text:
        if ch in emoji.UNICODE_EMOJI and ch not in color_emojis:
            # Separate the emoji from whatever precedes it.
            pieces.append(' ' + ch)
        else:
            pieces.append(ch)
    return ''.join(pieces)
def normalize_characters(text):
    """Apply a fixed, ordered list of literal substring replacements.

    The replacements run one after another, each on the result of the
    previous one: ``^^``, ``..`` and ``*`` get surrounded by spaces, and
    the typographic quotes ``„`` and ``“`` become a plain ``"``.
    """
    substitutions = (
        ('^^', ' ^^ '),
        ('„', '"'),
        ('“', '"'),
        ('..', ' .. '),
        ('*', ' * '),
        # NOTE(review): as it appears here this pair maps a single space to
        # itself, which is a no-op. It may originally have collapsed a
        # double space or replaced a non-breaking space; confirm against
        # the original file.
        (' ', ' '),
    )
    for old, new in substitutions:
        text = text.replace(old, new)
    return text
def normalize(text):
    """Return *text* as a string of tokens separated by single spaces.

    The text first goes through ``normalize_emojis`` and then
    ``normalize_characters``; the result is tokenized with
    ``nltk.word_tokenize`` using ``language='german'`` and the tokens are
    joined with one space each.
    """
    spaced = normalize_emojis(text)
    cleaned = normalize_characters(spaced)
    tokens = nltk.word_tokenize(cleaned, language='german')
    return ' '.join(tokens)
def write_data(data, name):
    """Write *data* to ``chat-data/<name>.txt`` as UTF-8, one item per line.

    The ``chat-data`` directory is created if it does not exist yet, so the
    script no longer fails with ``FileNotFoundError`` on a fresh checkout.
    An existing file of the same name is overwritten.

    Args:
        data: Iterable of strings. Items are joined with ``'\\n'``; no
            trailing newline is written.
        name: Output file name without the ``.txt`` extension.
    """
    # open(..., 'w') does not create missing parent directories.
    os.makedirs('chat-data', exist_ok=True)
    with open(path.join('chat-data', name + '.txt'), 'w', encoding='utf-8') as f:
        f.write('\n'.join(data))
# Build (source, target) pairs from consecutive lines of the chat export:
# each normalized line is the target for the line before it. Only the
# source side is case-folded; the target keeps its original casing.
src_train = []
tgt_train = []

with open('export/chat.txt', 'r', encoding='utf-8') as chat_file:
    # The first line can only ever be a source, never a target.
    prompt = normalize(next(chat_file))
    for raw_line in chat_file:
        reply = normalize(raw_line)
        src_train.append(prompt.casefold())
        tgt_train.append(reply)
        # This line becomes the source for the next one.
        prompt = reply
# Move a random 10% of the pairs (rounded down) from the training lists into
# the validation lists. The same index is removed from both lists so that
# source and target stay aligned.
i_max = len(src_train)
src_val = []
tgt_val = []

val_count = math.floor(i_max * 0.1)
for _ in range(val_count):
    # i_max tracks the current length of the shrinking training lists.
    pick = random.randrange(i_max)
    src_val.append(src_train.pop(pick))
    tgt_val.append(tgt_train.pop(pick))
    i_max -= 1
# Write the four splits to disk, in the same order as before:
# src-train, src-val, tgt-train, tgt-val.
for split_name, split_rows in (
    ('src-train', src_train),
    ('src-val', src_val),
    ('tgt-train', tgt_train),
    ('tgt-val', tgt_val),
):
    write_data(split_rows, split_name)