-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtest_custom
executable file
·46 lines (40 loc) · 1.31 KB
/
test_custom
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
#!/usr/bin/env Rscript
## * Package setup
## The workspace is deliberately NOT wiped with rm(list = ls()): under Rscript
## the global environment normally starts empty, and when this file is
## source()d from an interactive session the wipe would destroy the caller's
## objects.
## methods is attached explicitly for hasArg(); Rscript does not always attach
## it by default.
library(methods, quietly = TRUE)
library(devtools, quietly = TRUE)
pkg_path <- "."
## Best-effort unload of an already loaded copy of the package. A failure is
## reported by try() but does not abort the script.
try(unload(pkg_path))
## Regenerate the documentation, load the development version of the package,
## then run its test suite.
document(pkg_path)
load_all(pkg_path)
test(pkg_path)
## * Loading data set
## mini_mnist is requested by name; its `train` and `test` components each
## supply `images` (used as features) and `labels` (used as targets).
data("mini_mnist")
## ** features: in sample, then out of sample
feats <- mini_mnist$train$images
feats_oos <- mini_mnist$test$images
## ** targets: in sample, then out of sample
targets <- mini_mnist$train$labels
targets_oos <- mini_mnist$test$labels
## * Defining model
## ** initial weights
## One row per column of `feats` and one column per column of `targets`,
## filled with standard-normal draws from rnorm().
weights_init <- matrix(
  rnorm(NCOL(feats) * NCOL(targets)),
  nrow = NCOL(feats),
  ncol = NCOL(targets)
)
## * Training model
## ** Setting training variables
## Every value below is handed unchanged to train() for each backend.
## NOTE(review): max_iter is infinite, so presumably `tol` is the only
## stopping criterion -- confirm that train() accepts Inf on every backend.
tol <- 1e-6
step_size <- 0.01
decay <- 0
max_iter <- Inf
verbose <- FALSE
## ** Training models and outputting misclassification rates
## Every backend is trained from the same initial weights and settings, then
## its error is printed for the training data and for the held-out data.
backends <- c("R", "C", "CUDA")
for (backend in backends) {
  ## A new classifier is built on each pass before it is trained.
  model <- Classifier(weights_init = weights_init)
  model <- train(
    model, feats, targets,
    decay = decay,
    step_size = step_size,
    max_iter = max_iter,
    verbose = verbose,
    tol = tol,
    backend = backend
  )

  ## Compute both error figures before anything is printed.
  mis_is <- get_error(model, feats, targets)
  mis_oos <- get_error(model, feats_oos, targets_oos)

  label_is <- paste(backend, "backend in sample misclassification rate")
  label_oos <- paste(backend, "backend out of sample misclassification rate")
  print(label_is)
  print(mis_is)
  print(label_oos)
  print(mis_oos)
}