-
Notifications
You must be signed in to change notification settings - Fork 9
/
classification.py
55 lines (41 loc) · 1.46 KB
/
classification.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
"""Simple classification example.
This script demonstrates how to perform classification using
`maldi-learn` and the DRIAMS data set.
"""
import dotenv
import os
from maldi_learn.driams import DRIAMSDatasetExplorer
from maldi_learn.driams import DRIAMSLabelEncoder
from maldi_learn.driams import load_driams_dataset
from maldi_learn.utilities import stratify_by_species_and_label
from maldi_learn.vectorization import BinningVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
# Load variables from a local `.env` file (if present) into the environment
# before looking up the location of the DRIAMS data.
dotenv.load_dotenv()
DRIAMS_ROOT = os.getenv('DRIAMS_ROOT')

# `os.getenv` returns `None` for an unset variable. Fail right away with an
# actionable message instead of handing `None` on as the data set root.
if DRIAMS_ROOT is None:
    raise RuntimeError(
        'The environment variable DRIAMS_ROOT is not set. Export it or '
        'define it in a .env file so that it points to the DRIAMS data.'
    )

explorer = DRIAMSDatasetExplorer(DRIAMS_ROOT)

# Load the data used for this example. The positional arguments are
# presumably site, years, species, and antibiotics (in that order) -- confirm
# against the `load_driams_dataset` signature.
driams_dataset = load_driams_dataset(
    explorer.root,
    'DRIAMS-A',
    ['2015', '2017'],
    'Staphylococcus aureus',
    ['Ciprofloxacin', 'Penicillin'],
    encoder=DRIAMSLabelEncoder(),
    handle_missing_resistance_measurements='remove_if_all_missing',
)
# Bin the spectra: 100 bins, with `min_bin=2000` and `max_bin=20000` as the
# limits of the binned range.
bv = BinningVectorizer(100, min_bin=2000, max_bin=20000)
X = bv.fit_transform(driams_dataset.X)

# Train-test split, stratified with respect to the Ciprofloxacin labels.
index_train, index_test = stratify_by_species_and_label(
    driams_dataset.y, antibiotic='Ciprofloxacin'
)

# Show the indices of both partitions for inspection.
print(index_train)
print(index_test)

# Extract the labels of the antibiotic that is to be predicted.
y = driams_dataset.to_numpy('Ciprofloxacin')

# Sanity check of the training data before fitting.
print(y[index_train].dtype)
print(X[index_train].shape)

# Fit a logistic regression classifier (default parameters) on the
# training partition and predict the held-out partition.
lr = LogisticRegression()
lr.fit(X[index_train], y[index_train])
y_pred = lr.predict(X[index_test])

# scikit-learn metrics expect the ground truth first, i.e. the order is
# `(y_true, y_pred)`. Accuracy is symmetric, so the value is unaffected,
# but keeping the documented order avoids errors when swapping in an
# asymmetric metric such as precision or recall.
y_true = y[index_test]
print(accuracy_score(y_true, y_pred))