"""filters.py.
Last Update: June 12 2024
"""
import re
from typing import List, Optional, Union

import numpy as np
import spacy
from spacy.attrs import (DEP, ENT_ID, ENT_IOB, ENT_TYPE, IS_ALPHA, IS_ASCII,
                         IS_DIGIT, IS_LOWER, IS_PUNCT, IS_SPACE, IS_STOP,
                         IS_TITLE, IS_UPPER, LEMMA, LENGTH, LIKE_EMAIL,
                         LIKE_NUM, LIKE_URL, LOWER, MORPH, NORM, ORTH, POS,
                         SENT_START, SHAPE, SPACY, TAG)
from spacy.tokens import Doc

from rollingwindows import helpers

SPACY_ATTRS = [
    "DEP",
    "ENT_ID",
    "ENT_IOB",
    "ENT_TYPE",
    "IS_ALPHA",
    "IS_ASCII",
    "IS_DIGIT",
    "IS_LOWER",
    "IS_PUNCT",
    "IS_SPACE",
    "IS_STOP",
    "IS_TITLE",
    "IS_UPPER",
    "LEMMA",
    "LENGTH",
    "LIKE_EMAIL",
    "LIKE_NUM",
    "LIKE_URL",
    "LOWER",
    "MORPH",
    "NORM",
    "ORTH",
    "POS",
    "SENT_START",
    "SHAPE",
    "SPACY",
    "TAG",
]


def filter_doc(
    doc: spacy.tokens.doc.Doc,
    keep_ids: Union[list, set],
    spacy_attrs: List[str] = SPACY_ATTRS,
    force_ws: bool = True,
) -> spacy.tokens.doc.Doc:
    """Create a filtered doc, preserving desired spaCy attributes and whitespace.

    Args:
        doc (spacy.tokens.doc.Doc): A spaCy doc.
        keep_ids (Union[list, set]): The token ids to keep.
        spacy_attrs (List[str]): A list of spaCy attributes to preserve.
        force_ws (bool): Force a whitespace at the end of every token except the last.

    Returns:
        spacy.tokens.doc.Doc: A filtered doc.

    Note:
        In spaCy 3.6.1 `Doc.to_array()` seems to preserve custom attributes.
    """
    # Work on a copy so the caller's list (and the module-level default) is
    # never mutated. When `force_ws` is True the array gains a SPACY column,
    # so `from_array` must be called with the same attribute list.
    attrs = list(spacy_attrs)
    if force_ws and SPACY not in attrs:
        attrs.append(SPACY)
    # Use a set for O(1) membership tests.
    keep_ids = set(keep_ids)
    words = []
    remove_indexes = []
    for i, token in enumerate(doc):
        if i in keep_ids:
            words.append(token.text)
        else:
            remove_indexes.append(i)
    np_array = get_doc_array(doc, attrs, force_ws)
    np_array = np.delete(np_array, remove_indexes, axis=0)
    doc2 = Doc(doc.vocab, words=words)
    doc2.from_array(attrs, np_array)
    return doc2
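
# Example (illustrative sketch, not part of the module API): keep only
# alphabetic tokens. The model name "en_core_web_sm" is an assumption;
# any installed spaCy pipeline will do.
#
#     import spacy
#
#     nlp = spacy.load("en_core_web_sm")
#     doc = nlp("Chapter 1. It was a bright cold day.")
#     keep_ids = {t.i for t in doc if t.is_alpha}
#     filtered = filter_doc(doc, keep_ids)
#     filtered.text  # 'Chapter It was a bright cold day' (digit and punctuation removed)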


def get_doc_array(
    doc: spacy.tokens.doc.Doc,
    spacy_attrs: List[str] = SPACY_ATTRS,
    force_ws: bool = True,
) -> np.ndarray:
    """Get a numpy array of the doc.

    Args:
        doc (spacy.tokens.doc.Doc): A spaCy doc.
        spacy_attrs (List[str]): A list of spaCy attributes to preserve.
        force_ws (bool): Force a whitespace at the end of every token except the last.

    Returns:
        np.ndarray: A numpy array of the doc.

    Notes:
        1. `force_ws=True` ensures that the `text_with_ws` and `whitespace_`
           attributes are preserved, but all tokens will be separated by
           whitespace in the text of a doc created from the array.
        2. `force_ws=False` with `SPACY` in `spacy_attrs` preserves the
           `text_with_ws` and `whitespace_` attributes and their original
           values. This may cause tokens to be merged if subsequent
           processing operates on the `doc.text`.
        3. `force_ws=False` without `SPACY` in `spacy_attrs` does not preserve
           the `text_with_ws` and `whitespace_` attributes or their values.
           By default, `doc.text` displays a single space between each token.
    """
    if force_ws:
        # Copy the list so the caller's list (and the module-level default)
        # is not mutated by the append below.
        spacy_attrs = list(spacy_attrs)
        if SPACY not in spacy_attrs:
            spacy_attrs.append(SPACY)
        np_array = doc.to_array(spacy_attrs)
        # Mark every token but the last as followed by a space...
        np_array[:-1, spacy_attrs.index(SPACY)] = 1
        # ...and assume the last token has no trailing whitespace.
        np_array[-1, spacy_attrs.index(SPACY)] = 0
    else:
        np_array = doc.to_array(spacy_attrs)
    return np_array
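
# Example (illustrative): the effect of `force_ws` on the SPACY column.
#
#     arr = get_doc_array(doc, force_ws=True)
#     # The SPACY column is all 1s except the final row: every token but
#     # the last is marked as followed by a space.
#
#     arr = get_doc_array(doc, spacy_attrs=["ORTH", "SPACY"], force_ws=False)
#     # The original whitespace values are preserved instead.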


def is_not_roman_numeral(s: str) -> bool:
    """Detect Roman numerals (capitals only).

    Args:
        s (str): A string to match against the pattern.

    Returns:
        bool: True if the string is not a Roman numeral.
    """
    if s == "":
        return True
    pattern = r"^M{0,3}(CM|CD|D?C{0,3})(XC|XL|L?X{0,3})(IX|IV|V?I{0,3})$"
    return not bool(re.search(pattern, s))
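
# Illustrative behaviour of the helper above:
#
#     is_not_roman_numeral("XIV")   # False: a well-formed Roman numeral
#     is_not_roman_numeral("xiv")   # True: lower case is not matched
#     is_not_roman_numeral("IIII")  # True: not a well-formed numeral
#     is_not_roman_numeral("")      # True by definition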


class BaseFilter:
    """A base class for filters."""

    @property
    def metadata(self) -> dict:
        """Get metadata for the filter object."""
        exclude = ["doc"]
        metadata = {"id": self.id}
        # Collect instance attributes (names not defined on the class),
        # skipping the excluded names above.
        return metadata | {
            key: getattr(self, key)
            for key in dir(self)
            if key not in exclude and key not in dir(self.__class__)
        }
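
# Illustrative output (for the WordFilter subclass defined below; the
# exact keys depend on the attributes set in `__init__`):
#
#     wf = WordFilter(doc, exclude_digits=True)
#     wf.metadata
#     # {'id': 'word_filter', 'exclude': [' ', '\n'], 'exclude_digits': True,
#     #  'exclude_pattern': [], 'exclude_roman_numerals': False,
#     #  'spacy_attrs': [...]}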


class WordFilter(BaseFilter):
    """A filter to remove non-words from a spaCy doc."""

    id: str = "word_filter"

    def __init__(
        self,
        doc: spacy.tokens.doc.Doc,
        *,
        spacy_attrs: List[str] = SPACY_ATTRS,
        exclude: Union[List[str], str] = [" ", "\n"],
        exclude_digits: bool = False,
        exclude_roman_numerals: bool = False,
        exclude_pattern: Optional[Union[List[str], str]] = None,
    ):
        """Initialise the filter object with configuration.

        Args:
            doc (spacy.tokens.doc.Doc): A spaCy doc.
            spacy_attrs (List[str]): A list of spaCy token attributes to preserve in the filtered doc.
            exclude (Union[List[str], str]): A string/regex or list of strings/regex patterns to exclude.
            exclude_digits (bool): If True, digits will not be treated as words.
            exclude_roman_numerals (bool): If True, Roman numerals (capital letters only) will not be treated as words.
            exclude_pattern (Optional[Union[List[str], str]]): Additional patterns to add to the default exclude list.
        """
        self.doc = doc
        self.spacy_attrs = spacy_attrs
        self.exclude = []
        self.exclude_digits = exclude_digits
        self.exclude_roman_numerals = exclude_roman_numerals
        self.exclude_pattern = []
        if exclude:
            self.exclude = helpers.ensure_list(exclude)
        if exclude_pattern:
            self.exclude_pattern = helpers.ensure_list(exclude_pattern)

    @property
    def word_ids(self):
        """Get a set of word_ids to keep after filtering."""
        predicates = []
        if self.exclude_digits:
            predicates.append(lambda t: t.text.isalpha())
        else:
            predicates.append(lambda t: t.text.isalpha() or t.text.isdigit())
        if self.exclude_roman_numerals:
            predicates.append(lambda token: is_not_roman_numeral(token.text))
        # Combine the exclusion lists locally so that repeated access to this
        # property does not keep appending `exclude_pattern` to `self.exclude`
        # (and, through it, to the shared default list).
        exclude = list(self.exclude) + list(self.exclude_pattern)
        if exclude:
            exclude_pat = "|".join(exclude)
            predicates.append(lambda token: re.search(exclude_pat, token.text) is None)
        return {t.i for t in self.doc if all(f(t) for f in predicates)}

    def apply(self) -> spacy.tokens.doc.Doc:
        """Apply the filter.

        Returns:
            spacy.tokens.doc.Doc: A spaCy Doc.
        """
        return filter_doc(self.doc, self.word_ids, self.spacy_attrs)
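
# Example usage (sketch; the model name is an assumption):
#
#     nlp = spacy.load("en_core_web_sm")
#     doc = nlp("Chapter IV was written in 1984.")
#     wf = WordFilter(doc, exclude_digits=True, exclude_roman_numerals=True)
#     wf.apply().text  # 'Chapter was written in' (digits, numerals, punctuation dropped)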


class NonStopwordFilter(BaseFilter):
    """A filter to remove stop words from a spaCy doc."""

    id: str = "non_stopword_filter"

    def __init__(
        self,
        doc: spacy.tokens.doc.Doc,
        *,
        spacy_attrs: List[str] = SPACY_ATTRS,
        additional_stopwords: Optional[List[str]] = None,
        case_sensitive: bool = False,
    ):
        """Initialise the filter object with configuration.

        Args:
            doc (spacy.tokens.doc.Doc): A spaCy doc.
            spacy_attrs (List[str]): A list of spaCy token attributes to preserve in the filtered doc.
            additional_stopwords (Optional[List[str]]): A list of stop words to add to those labelled as stop words by the model.
            case_sensitive (bool): If False, match additional stop words on lowercased forms.

        Note:
            This is a minimal filter that strips punctuation and keeps the ids of
            tokens not flagged as stop words in the doc or listed in an additional
            stop words list.
        """
        self.doc = doc
        self.spacy_attrs = spacy_attrs
        self.additional_stopwords = additional_stopwords
        self.case_sensitive = case_sensitive

    @property
    def word_ids(self):
        """Get a set of word_ids to keep after filtering."""
        # Normalise `additional_stopwords` to a set (lowercased unless the
        # filter is case sensitive) the first time the property is accessed.
        if not self.additional_stopwords:
            self.additional_stopwords = set()
        else:
            self.additional_stopwords = set(
                helpers.ensure_list(self.additional_stopwords)
            )
        if not self.case_sensitive:
            self.additional_stopwords = {
                text.lower() for text in self.additional_stopwords
            }
        return {token.i for token in self.doc if self._is_non_stopword(token)}

    def _is_non_stopword(self, token: spacy.tokens.Token) -> bool:
        """Check if a token should be retained.

        Args:
            token (spacy.tokens.Token): A spaCy token.

        Returns:
            bool: True if the token should be retained.
        """
        text = token.text if self.case_sensitive else token.lower_
        return (
            not token.is_punct
            and not token.is_stop
            and text not in self.additional_stopwords
        )

    def apply(self) -> spacy.tokens.doc.Doc:
        """Apply the filter.

        Returns:
            spacy.tokens.doc.Doc: The filtered doc.
        """
        return filter_doc(self.doc, self.word_ids, self.spacy_attrs)
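
# Example usage (sketch; the model name is an assumption):
#
#     nlp = spacy.load("en_core_web_sm")
#     doc = nlp("The quick brown fox jumps over the lazy dog.")
#     nsf = NonStopwordFilter(doc, additional_stopwords=["fox"])
#     nsf.apply().text  # 'quick brown jumps lazy dog' (stop words, "fox", and punctuation dropped)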