# tokenizer.py
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import (
Callable,
List,
Tuple,
Union
)
from os import PathLike
from data_loaders import JSONLoader
from utils import save_json
from functools import wraps
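# Special-token string constants shared by the tokenizers below.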
OOV = '<OOV>'
PHI = '_'
SOS = '<SOS>'
EOS = '<EOS>'
PAD = '<PAD>'


def check_token(token: str) -> Callable:
    """Decorator factory for the ``add_*_token`` methods.

    If ``token`` is already in the vocabulary, the wrapper returns its
    existing id instead of calling the wrapped add method.

    Args:
        token (str): the default token the decorated method registers.
    """
def decorator(func):
@wraps(func)
def wrapper(obj, token=token):
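            # If the token is already registered, short-circuit and return
            # its existing id.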
if token in obj._token_to_id:
return obj._token_to_id[token]
return func(obj, token)
return wrapper
    return decorator


@dataclass
class SpecialTokens:
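    """Container for the (token, id) pairs of a tokenizer's special tokens."""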
_oov: Tuple[str, int] = (None, None)
_pad: Tuple[str, int] = (None, None)
_sos: Tuple[str, int] = (None, None)
_eos: Tuple[str, int] = (None, None)
    _phi: Tuple[str, int] = (None, None)
    _mask: Tuple[str, int] = (None, None)
@property
def oov_id(self):
return self._oov[1]
@property
def oov_token(self):
return self._oov[0]
@property
def pad_id(self):
return self._pad[1]
@property
def pad_token(self):
return self._pad[0]
@property
def sos_id(self):
return self._sos[1]
@property
def sos_token(self):
return self._sos[0]
@property
def eos_id(self):
return self._eos[1]
@property
def eos_token(self):
return self._eos[0]
@property
def mask_id(self):
return self._mask[1]
@property
def mask_token(self):
return self._mask[0]
@property
def phi_id(self):
return self._phi[1]
@property
def phi_token(self):
        return self._phi[0]


class ITokenizer(ABC):
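    """Interface that every tokenizer implementation has to satisfy."""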
@abstractmethod
def ids2tokens(self):
pass
@abstractmethod
def tokens2ids(self):
pass
@abstractmethod
def set_tokenizer(self):
pass
@abstractmethod
def save_tokenizer(self):
pass
@abstractmethod
def load_tokenizer(self):
pass
@abstractmethod
def add_token(self):
pass
@abstractmethod
def preprocess_tokens(self):
pass
@abstractmethod
def batch_tokenizer(self):
pass
    @property
    @abstractmethod
    def vocab_size(self):
pass
@abstractmethod
def get_tokens(self):
        pass


class BaseTokenizer(ITokenizer):
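    """Shared vocabulary management: token/id maps, special tokens, and
    JSON (de)serialization. Subclasses provide ``get_tokens`` and
    ``preprocess_tokens``.
    """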
_oov_key = 'oov'
_sos_key = 'sos'
_eos_key = 'eos'
_pad_key = 'pad'
_phi_key = 'phi'
_token_to_id_key = 'token_to_id'
_special_tokens_key = 'special_tokens'
def __init__(self) -> None:
super().__init__()
self._token_to_id = dict()
self._id_to_token = dict()
self.special_tokens = SpecialTokens()
@property
def vocab_size(self):
return len(self._token_to_id)
def add_token(self, token: str):
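        """Assign the next free id to ``token`` and register it in both the
        token-to-id and id-to-token maps, returning the new id.
        """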
token_id = self.vocab_size
self._token_to_id[token] = token_id
self._id_to_token[token_id] = token
return token_id
@check_token(OOV)
def add_oov_token(self, token=OOV) -> ITokenizer:
token_id = self.add_token(token)
self.special_tokens._oov = (token, token_id)
return self
@check_token(PAD)
def add_pad_token(self, token=PAD) -> ITokenizer:
token_id = self.add_token(token)
self.special_tokens._pad = (token, token_id)
return self
@check_token(SOS)
def add_sos_token(self, token=SOS) -> ITokenizer:
token_id = self.add_token(token)
self.special_tokens._sos = (token, token_id)
return self
@check_token(EOS)
def add_eos_token(self, token=EOS) -> ITokenizer:
token_id = self.add_token(token)
self.special_tokens._eos = (token, token_id)
return self
@check_token(PHI)
def add_phi_token(self, token=PHI) -> ITokenizer:
token_id = self.add_token(token)
self.special_tokens._phi = (token, token_id)
return self
def _reset_id_to_token(self) -> None:
self._id_to_token = dict(zip(
self._token_to_id.values(),
self._token_to_id.keys()
))
def __set_special_tokens_dict(self, data: dict) -> None:
if self._oov_key in data:
self.special_tokens._oov = tuple(data[self._oov_key])
if self._pad_key in data:
self.special_tokens._pad = tuple(data[self._pad_key])
if self._sos_key in data:
self.special_tokens._sos = tuple(data[self._sos_key])
if self._eos_key in data:
self.special_tokens._eos = tuple(data[self._eos_key])
if self._phi_key in data:
self.special_tokens._phi = tuple(data[self._phi_key])
def __get_special_tokens_dict(self) -> dict:
data = {}
if self.special_tokens.oov_id is not None:
data[self._oov_key] = list(self.special_tokens._oov)
if self.special_tokens.pad_id is not None:
data[self._pad_key] = list(self.special_tokens._pad)
if self.special_tokens.sos_id is not None:
data[self._sos_key] = list(self.special_tokens._sos)
if self.special_tokens.eos_id is not None:
data[self._eos_key] = list(self.special_tokens._eos)
if self.special_tokens.phi_id is not None:
data[self._phi_key] = list(self.special_tokens._phi)
return data
def load_tokenizer(
self,
tokenizer_path: Union[str, PathLike],
*args,
**kwargs
) -> ITokenizer:
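        """Restore the vocabulary and special tokens from a JSON file
        previously written by ``save_tokenizer``.
        """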
data = JSONLoader(tokenizer_path).load()
self._token_to_id = data[self._token_to_id_key]
self.__set_special_tokens_dict(data[self._special_tokens_key])
self._reset_id_to_token()
return self
def set_tokenizer(self, data: List[str], *args, **kwargs) -> ITokenizer:
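        """Build the vocabulary from ``data`` using the tokens returned by
        ``get_tokens``.
        """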
all_tokens = self.get_tokens(data)
_ = list(map(self.add_token, all_tokens))
self._reset_id_to_token()
return self
def save_tokenizer(
self,
save_path: Union[str, PathLike],
*args,
**kwargs
) -> None:
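        """Serialize the vocabulary and special tokens to ``save_path`` as JSON."""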
data = {
self._token_to_id_key: self._token_to_id,
self._special_tokens_key: self.__get_special_tokens_dict()
}
save_json(save_path, data)
    def ids2tokens(self, ids: List[int]) -> List[str]:
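        """Map each id back to its token string."""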
return list(map(lambda x: self._id_to_token[x], ids))
def tokens2ids(self, sentence: str) -> List[int]:
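        """Preprocess ``sentence`` and map each resulting token to its id,
        falling back to the OOV id for unknown tokens.
        """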
sentence = self.preprocess_tokens(sentence)
return list(map(
lambda x: self._token_to_id.get(x, self.special_tokens.oov_id),
sentence)
)
    def batch_tokenizer(self, data: List[str]) -> List[List[int]]:
        return list(map(self.tokens2ids, data))

    def batch_detokenizer(self, data: List[List[int]]) -> List[List[str]]:
        return list(map(self.ids2tokens, data))


class CharTokenizer(BaseTokenizer):
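    """Character-level tokenizer: every distinct character in the corpus
    becomes a vocabulary entry.
    """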
def __init__(self) -> None:
super().__init__()
def get_tokens(self, data: List[str]):
return set(''.join(data))
def preprocess_tokens(self, sentence: str) -> List[str]:
return list(sentence)
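

# Minimal usage sketch (added for illustration; not part of the original
# module). It relies only on the classes defined above; the corpus strings
# and the 'tokenizer.json' output path are arbitrary example values.
if __name__ == '__main__':
    corpus = ['hello world', 'hola mundo']

    # Build a character-level vocabulary with OOV and PAD entries.
    tokenizer = CharTokenizer()
    tokenizer.add_oov_token().add_pad_token()
    tokenizer.set_tokenizer(corpus)

    # Encode a sentence; characters never seen in the corpus map to the OOV id.
    ids = tokenizer.tokens2ids('hello!')
    print(ids)
    print(tokenizer.ids2tokens(ids))

    # Encode a whole batch, then persist the vocabulary as JSON.
    print(tokenizer.batch_tokenizer(corpus))
    tokenizer.save_tokenizer('tokenizer.json')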