-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_feature_selection.py
60 lines (47 loc) · 1.31 KB
/
run_feature_selection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
import sys
import warnings
import pickle
import click
from .utils import constants, load_splits, load_best_pipeline, get_simple_pipeline
# Silence warnings only when the user did not pass explicit -W options
# on the command line (sys.warnoptions is empty in that case).
if not sys.warnoptions:
    # Exported through the environment so that Python interpreters started
    # later from this process read the same "ignore" setting at startup.
    os.environ["PYTHONWARNINGS"] = "ignore"
    # Filter for the current interpreter.
    warnings.simplefilter("ignore")
@click.command()
@click.argument(
    'path-cls-dir',
    type=click.Path(exists=True),
)
@click.argument(
    'path-data-dir',
    type=click.Path(exists=True),
)
def main(
    path_cls_dir,
    path_data_dir,
):
    """Run forward sequential feature selection and pickle the result.

    The classifier pipeline is obtained from ``load_best_pipeline`` and the
    training data plus cross-validation splits from ``load_splits``. The
    training features are transformed with the pipeline returned by
    ``get_simple_pipeline`` and an mlxtend ``SequentialFeatureSelector`` is
    fitted on them. The fitted selector is written to
    ``<path_cls_dir>/classifiers/sfs.pkl``.

    Args:
        path_cls_dir: Existing path handed to ``load_best_pipeline``; the
            selector is saved under its ``classifiers`` subdirectory.
        path_data_dir: Existing path handed to ``load_splits``.
    """
    clf_pipeline = load_best_pipeline(path_cls_dir)
    # The second element returned by load_splits is not needed here.
    data_train, _, cv_splits = load_splits(path_data_dir)

    # Split off the target column; everything else is treated as a feature.
    # NOTE(review): data_train is presumably a pandas DataFrame - confirm.
    target_column = 'result translation'
    X_train = data_train.drop(columns=[target_column])
    y_train = data_train[target_column]

    # Fit the preprocessing pipeline on the training features, then apply it.
    prep_pipeline = get_simple_pipeline(
        X_train, constants.FEATURES_CAT
    )
    prep_pipeline.fit(X_train)
    X_train_transformed = prep_pipeline.transform(X_train)

    # Make sure the output directory exists *before* the search starts, so a
    # missing directory cannot discard the fitted selector at the very end.
    output_dir = os.path.join(path_cls_dir, "classifiers")
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "sfs.pkl")

    # Imported lazily, as in the original script, so mlxtend is only loaded
    # when the command actually runs.
    from mlxtend.feature_selection import SequentialFeatureSelector

    sfs = SequentialFeatureSelector(
        clf_pipeline,
        k_features='parsimonious',
        forward=True,
        floating=False,
        verbose=2,
        scoring='f1',
        cv=cv_splits,
        n_jobs=-1,
    )
    sfs.fit(X_train_transformed, y_train)

    with open(output_path, 'wb') as f:
        pickle.dump(sfs, f)
# Invoke the click command only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()