Skip to content

Commit

Permalink
v0.0.13
Browse files Browse the repository at this point in the history
  • Loading branch information
rosefun committed Nov 21, 2020
1 parent 50ddec6 commit 7f8a310
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 3 deletions.
37 changes: 37 additions & 0 deletions examples/run_S3VM.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from __future__ import absolute_import
import numpy as np
from sklearn import datasets
from sklearn import metrics
from sklearn.model_selection import train_test_split


# normalization
def normalize(x):
return (x - np.min(x))/(np.max(x) - np.min(x))

def get_data():
X, y = datasets.load_breast_cancer(return_X_y=True)
X = normalize(X)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.6, random_state = 0)
rng = np.random.RandomState(42)
random_unlabeled_points = rng.rand(len(X_train)) < 0.1
y_train[random_unlabeled_points] = -1
#
index, = np.where(y_train != -1)
label_X_train = X_train[index,:]
label_y_train = y_train[index]
index, = np.where(y_train == -1)
unlabel_X_train = X_train[index,:]
unlabel_y = -1*np.ones(unlabel_X_train.shape[0]).astype(int)
return label_X_train, label_y_train, unlabel_X_train, unlabel_y, X_test, y_test

if __name__ == "__main__":
from semisupervised import S3VM

label_X_train, label_y_train, unlabel_X_train, unlabel_y, X_test, y_test = get_data()
# S3VM
model = S3VM()
model.fit(np.vstack((label_X_train,unlabel_X_train)), np.append(label_y_train, unlabel_y))
predict = model.predict(X_test)
acc = metrics.accuracy_score(y_test, predict)
print("S3VM accuracy", acc)
10 changes: 7 additions & 3 deletions examples/run_pseudo_label_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn import metrics

class DNN(object):
"""
Expand Down Expand Up @@ -64,9 +65,12 @@ def get_data():

from semisupervised.PseudoLabelSSL import PseudoCallback, PseudoLabelNeuralNetworkClassifier
pseudo_callback = PseudoCallback()
#print("pseudo_callback.pretrain", pseudo_callback.pretrain)
clf = PseudoLabelNeuralNetworkClassifier(DNNmodel, pseudo_callback)
clf.fit(np.vstack((label_X_train, unlabel_X_train)), np.append(label_y_train, unlabel_y))

model = PseudoLabelNeuralNetworkClassifier(DNNmodel, pseudo_callback)
model.fit(np.vstack((label_X_train, unlabel_X_train)), np.append(label_y_train, unlabel_y))
predict = model.predict(X_test)
acc = metrics.accuracy_score(y_test, predict)
print("pseudo-label accuracy", acc)



0 comments on commit 7f8a310

Please sign in to comment.