Skip to content

Commit

Permalink
Support mlp classifier (#33)
Browse files Browse the repository at this point in the history
* rename knn_classifier to classifier

* remove deprecated file

* support changing classifier type

* add classifier type flag

* update readme

* linting
  • Loading branch information
santi1234567 authored May 16, 2024
1 parent c7f570d commit 34b8cab
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 23 deletions.
15 changes: 10 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,16 @@ pip install -r requirements.txt
pip install -r requirements-dev.txt
```

### k-NN Classifier
### The Classifier

Blockprint's classifier is a k-nearest neighbours classifier in `knn_classifier.py`.
Blockprint's classifier utilizes one of two machine learning algorithms:

See `./knn_classifier.py --help` for command line options including cross
- K-nearest neighbours
- Multi-layer Perceptron

These can be chosen with the `--classifier-type` flag in `classifier.py`.

See `./classifier.py --help` for more command line options including cross
validation (CV) and manual classification.

### Training the Classifier
Expand Down Expand Up @@ -81,10 +86,10 @@ testdata_proc
└── 0x7fedb0da9699c93ce66966555c6719e1159ae7b3220c7053a08c8f50e2f3f56f.json
```

You can then use this directory as the datadir argument to `./knn_classifier.py`:
You can then use this directory as the datadir argument to `./classifier.py`:

```
./knn_classifier.py testdata_proc --classify testdata
./classifier.py testdata_proc --classify testdata
```

If you then want to use the classifier to build an sqlite database:
Expand Down
4 changes: 2 additions & 2 deletions build_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
import sqlite3
import argparse
from knn_classifier import Classifier
from classifier import Classifier
from multi_classifier import MultiClassifier
from prepare_training_data import CLIENTS

Expand Down Expand Up @@ -370,7 +370,7 @@ def main():
if args.multi_classifier:
classifier = MultiClassifier(data_dir)
else:
print("loading single KNN classifier")
print("loading single classifier")
classifier = Classifier(data_dir)
print("loaded")

Expand Down
41 changes: 32 additions & 9 deletions knn_classifier.py → classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,16 @@
import pickle

from sklearn.neighbors import KNeighborsClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import cross_validate
from feature_selection import * # noqa F403
from feature_selection import ALL_FEATURES
from prepare_training_data import CLIENTS, classify_reward_by_graffiti

K = 9

MLP_HIDDEN_LAYER_SIZES = (390, 870)

WEIGHTS = "distance"

MIN_GUESS_THRESHOLD = 0.20
Expand Down Expand Up @@ -69,6 +73,8 @@ def __init__(
graffiti_only_clients=DEFAULT_GRAFFITI_ONLY,
features=DEFAULT_FEATURES,
enable_cv=False,
classifier_type="knn",
hidden_layer_sizes=MLP_HIDDEN_LAYER_SIZES,
):
graffiti_only_clients = set(graffiti_only_clients)

Expand All @@ -82,6 +88,8 @@ def __init__(
set(grouped_clients) & graffiti_only_clients == set()
), "clients must not be both graffiti-only and grouped"

assert classifier_type in ["knn", "mlp"], "classifier_type must be knn or mlp"

feature_matrix = []
training_labels = []

Expand Down Expand Up @@ -118,18 +126,24 @@ def __init__(

feature_matrix = np.array(feature_matrix)

knn = KNeighborsClassifier(n_neighbors=K, weights=WEIGHTS)
if classifier_type == "knn":
classifier = KNeighborsClassifier(n_neighbors=K, weights=WEIGHTS)
elif classifier_type == "mlp":
classifier = MLPClassifier(
hidden_layer_sizes=hidden_layer_sizes, max_iter=1000
)
# Assert above makes sure that classifier_type is one of the valid types

if enable_cv:
self.scores = cross_validate(
knn, feature_matrix, training_labels, scoring="balanced_accuracy"
classifier, feature_matrix, training_labels, scoring="balanced_accuracy"
)
else:
self.scores = None

knn.fit(feature_matrix, training_labels)
classifier.fit(feature_matrix, training_labels)

self.knn = knn
self.classifier = classifier
self.enabled_clients = enabled_clients
self.graffiti_only_clients = set(graffiti_only_clients)
self.features = features
Expand All @@ -145,7 +159,7 @@ def classify(self, block_reward):
return (graffiti_guess, graffiti_guess, prob_by_client, graffiti_guess)

row = into_feature_row(block_reward, self.features)
res = self.knn.predict_proba([row])
res = self.classifier.predict_proba([row])

prob_by_client = {
client: res[0][i] for i, client in enumerate(self.enabled_clients)
Expand Down Expand Up @@ -219,7 +233,7 @@ def compute_best_guess(probability_map) -> str:


def parse_args():
parser = argparse.ArgumentParser("KNN testing and cross validation")
parser = argparse.ArgumentParser("Classifier testing and cross validation")

parser.add_argument("data_dir", help="training data directory")
parser.add_argument("--classify", help="data to classify")
Expand All @@ -235,6 +249,12 @@ def parse_args():
parser.add_argument(
"--group", default=[], nargs="+", help="clients to group during classification"
)
parser.add_argument(
"--classifier-type",
default="knn",
choices=["knn", "mlp"],
help="the type of classifier to use",
)
parser.add_argument(
"--persist",
action="store_true",
Expand Down Expand Up @@ -280,7 +300,7 @@ def main():
grouped_clients = args.group
should_persist = args.should_persist
graffiti_only = args.graffiti_only

classifier_type = args.classifier_type
disabled_clients = args.disable
enabled_clients = [
client
Expand Down Expand Up @@ -310,6 +330,7 @@ def main():
graffiti_only_clients=graffiti_only,
features=feature_vec,
enable_cv=True,
classifier_type=classifier_type,
)
print(f"enabled clients: {classifier.enabled_clients}")
print(f"classifier scores: {classifier.scores['test_score']}")
Expand All @@ -327,7 +348,9 @@ def main():
assert classify_dir is not None, "classify dir required"
print(f"classifying all data in directory {classify_dir}")
print(f"grouped clients: {grouped_clients}")
classifier = Classifier(data_dir, grouped_clients=grouped_clients)
classifier = Classifier(
data_dir, grouped_clients=grouped_clients, classifier_type=classifier_type
)

if args.plot is not None:
classifier.plot_feature_matrix(args.plot)
Expand All @@ -354,7 +377,7 @@ def main():
print(f"total blocks processed: {total_blocks}")

if should_persist:
persist_classifier(classifier, "knn_classifier")
persist_classifier(classifier, "classifier")

for multilabel, num_blocks in sorted(frequency_map.items()):
percentage = round(num_blocks / total_blocks, 4)
Expand Down
2 changes: 1 addition & 1 deletion compute_periods.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import sqlite3
import requests
import statistics
from knn_classifier import compute_best_guess
from classifier import compute_best_guess
from prepare_training_data import CLIENTS
from build_db import block_row_to_obj

Expand Down
4 changes: 2 additions & 2 deletions interactive.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"metadata": {},
"outputs": [],
"source": [
"from knn_classifier import Classifier, DEFAULT_FEATURES"
"from classifier import Classifier, DEFAULT_FEATURES"
]
},
{
Expand All @@ -19,7 +19,7 @@
"source": [
"datadir = \"data/mainnet/training/slots_3481601_to_3702784_bal2x\"\n",
"disabled_clients = []\n",
"features = ['percent_redundant', 'percent_pairwise_ordered', 'norm_reward']\n",
"features = [\"percent_redundant\", \"percent_pairwise_ordered\", \"norm_reward\"]\n",
"\n",
"classifier = Classifier(datadir, disabled_clients=disabled_clients, features=features)"
]
Expand Down
2 changes: 1 addition & 1 deletion multi_classifier.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os

from knn_classifier import Classifier
from classifier import Classifier


def start_and_end_slot(sub_dir_name) -> (int, int):
Expand Down
4 changes: 2 additions & 2 deletions prepare_training_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,13 @@ def process_file(


def parse_args():
parser = argparse.ArgumentParser("create training data for the KNN classifier")
parser = argparse.ArgumentParser("create training data for the classifier")

parser.add_argument(
"raw_data_dir", help="input containing data to classify using graffiti"
)
parser.add_argument(
"proc_data_dir", help="output for processed data, suitable for KNN training"
"proc_data_dir", help="output for processed data, suitable for training"
)
parser.add_argument(
"--disable",
Expand Down
2 changes: 1 addition & 1 deletion tests/test_classifier_persister.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
import os
from typing import Any, Dict, List
from knn_classifier import Classifier, persist_classifier
from classifier import Classifier, persist_classifier
from prepare_training_data import CLIENTS


Expand Down

0 comments on commit 34b8cab

Please sign in to comment.