
A simple Extreme Learning Machine implementation in Scala with Breeze



sirCamp/simple-extreme-learning-machine


License: MIT | Language: Scala

A simple implementation of an Extreme Learning Machine

This is a simple implementation of an Extreme Learning Machine (ELM) in Scala, built on the Breeze linear algebra library.
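For reference, the standard ELM formulation is summarized below. The README does not spell out the exact math, so treat this as an assumption about what the library computes (it may, for instance, omit the bias term or use a regularized solver). Here X is the N x d input matrix (d = featuresLength), W the d x L random hidden weights (L = hiddenLayerDimension), b a length-L bias vector, g the element-wise activation function, and T the N x C one-hot target matrix. Only the output weights beta are learned, in closed form:

$$
H = g\!\left(XW + \mathbf{1}_N b^{\top}\right) \in \mathbb{R}^{N \times L},
\qquad
\beta = H^{\dagger} T \in \mathbb{R}^{L \times C},
\qquad
\hat{Y} = g\!\left(X_{\mathrm{test}} W + \mathbf{1} b^{\top}\right)\beta
$$

Here $H^{\dagger}$ is the Moore-Penrose pseudoinverse of $H$, which equals $(H^{\top}H)^{-1}H^{\top}$ when $H$ has full column rank. The predicted class of each test sample is the index of the largest entry in its row of $\hat{Y}$.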

The library is designed for fast scientific computation: an ELM reduces to a few matrix operations, and Breeze runs them with optimized linear algebra routines.
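To show which matrix operations an ELM actually performs, here is a dependency-free sketch in plain Scala. It is not the library's code. The object `ElmSketch`, its `fit`/`predict` methods, the ridge term `lambda`, and the uniform [-1, 1) initialization are hypothetical choices for illustration. Instead of a pseudoinverse, it solves the ridge-regularized normal equations (H^T H + lambda I) beta = H^T T with Gauss-Jordan elimination.

```scala
import scala.util.Random

// Hypothetical, dependency-free ELM sketch (not the library's API).
// Matrices are row-major Array[Array[Double]].
object ElmSketch {
  type Mat = Array[Array[Double]]

  def matMul(a: Mat, b: Mat): Mat = {
    val inner = b.length
    Array.tabulate(a.length, b(0).length) { (i, j) =>
      var s = 0.0
      var t = 0
      while (t < inner) { s += a(i)(t) * b(t)(j); t += 1 }
      s
    }
  }

  def transpose(a: Mat): Mat = Array.tabulate(a(0).length, a.length)((i, j) => a(j)(i))

  // Solves A * X = B (A square) by Gauss-Jordan elimination with partial pivoting.
  def solve(aIn: Mat, bIn: Mat): Mat = {
    val a = aIn.map(_.clone)
    val b = bIn.map(_.clone)
    val n = a.length
    for (col <- 0 until n) {
      val pivot = (col until n).maxBy(r => math.abs(a(r)(col)))
      val tmpA = a(col); a(col) = a(pivot); a(pivot) = tmpA
      val tmpB = b(col); b(col) = b(pivot); b(pivot) = tmpB
      val p = a(col)(col)
      for (j <- 0 until n) a(col)(j) /= p
      for (j <- b(col).indices) b(col)(j) /= p
      for (r <- 0 until n if r != col) {
        val f = a(r)(col)
        if (f != 0.0) {
          for (j <- 0 until n) a(r)(j) -= f * a(col)(j)
          for (j <- b(r).indices) b(r)(j) -= f * b(col)(j)
        }
      }
    }
    b
  }

  // H = g(X W + b): one row per sample, one column per hidden neuron.
  def hidden(x: Mat, w: Mat, bias: Array[Double], g: Double => Double): Mat =
    matMul(x, w).map(row => row.indices.map(j => g(row(j) + bias(j))).toArray)

  final case class Model(w: Mat, bias: Array[Double], beta: Mat, g: Double => Double) {
    // Y_hat = H(X) * beta; each row holds one score per class.
    def predict(x: Mat): Mat = matMul(hidden(x, w, bias, g), beta)
  }

  // Random hidden layer, then beta = (H^T H + lambda I)^-1 H^T T.
  def fit(x: Mat, t: Mat, nHidden: Int, lambda: Double = 1e-6,
          g: Double => Double = (v: Double) => math.tanh(v), seed: Long = 42L): Model = {
    val rnd = new Random(seed)
    val w = Array.fill(x(0).length, nHidden)(rnd.nextDouble() * 2 - 1) // uniform in [-1, 1)
    val bias = Array.fill(nHidden)(rnd.nextDouble() * 2 - 1)
    val h = hidden(x, w, bias, g)
    val ht = transpose(h)
    val hth = matMul(ht, h)
    for (i <- 0 until nHidden) hth(i)(i) += lambda
    Model(w, bias, solve(hth, matMul(ht, t)), g)
  }
}
```

For example, `ElmSketch.fit(x, t, nHidden = 20).predict(x)` on the four XOR points with one-hot targets recovers the XOR classes. Only `beta` is trained, so fitting is a single linear solve with no gradient descent, which is what makes an ELM cheap when a fast linear algebra backend such as Breeze does the work.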

How to use

To use the library, take a look at the following example:

 import breeze.linalg.{DenseMatrix, DenseVector}
 import com.sircamp.elm.ExtremeLearningMachine

 // XTrain, yTrain and XTest are DenseMatrix[Double] loaded beforehand (e.g. from MNIST);
 // yPlain holds the true test labels as class indices (not one-hot encoded).

 val featuresLength = 28 * 28 // MNIST: 28x28 pixel images
 val hiddenLayerDimension = 1024
 val elm = new ExtremeLearningMachine(featuresLength, hiddenLayerDimension)

 /**
  Initialize the hidden-layer weights from a random uniform distribution.
  Alternatively, you can set the weights yourself: they must be a
  DenseMatrix[Double] whose number of rows equals featuresLength.
 **/
 elm.initializeWeights()

 /**
  Fit the model.
  XTrain and yTrain must be DenseMatrix[Double];
  yTrain must be the one-hot encoded version of the original labels.
 **/
 elm.fit(XTrain, yTrain)

 /**
  predictClasses returns a DenseVector[Int] with the index of the predicted class for each row.
 **/
 val yPred: DenseVector[Int] = elm.predictClasses(XTest)

 /**
  predict returns a DenseMatrix[Double] where each row contains the probability
  that the sample belongs to each class.
 **/
 val yProbabilityPred: DenseMatrix[Double] = elm.predict(XTest)

 // Metrics.accuracy_score compares the true labels with the predicted ones
 // (Metrics is assumed to ship with the project; its import is not shown here).
 println("Accuracy: " + Metrics.accuracy_score(yPlain, yPred))
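The example expects one-hot targets for `fit` and returns class indices from `predictClasses`. The sketch below illustrates those label formats in plain Scala. The object `LabelEncoding` and its methods `oneHot`, `argmaxRows` and `accuracy` are hypothetical helpers, not part of the library. `argmaxRows` mirrors what a "predicted class index" means for a row of scores, and `accuracy` computes the fraction of matching labels, which is presumably what `Metrics.accuracy_score` reports.

```scala
// Hypothetical helpers (not the library's API) illustrating the label formats above.
object LabelEncoding {
  // Class index -> one-hot row, e.g. 2 with 4 classes -> [0.0, 0.0, 1.0, 0.0].
  def oneHot(labels: Seq[Int], numClasses: Int): Array[Array[Double]] =
    labels.map { y =>
      require(y >= 0 && y < numClasses, s"label $y out of range")
      Array.tabulate(numClasses)(c => if (c == y) 1.0 else 0.0)
    }.toArray

  // Row of class scores -> index of the highest score (first one on ties).
  def argmaxRows(scores: Array[Array[Double]]): Array[Int] =
    scores.map(row => row.indices.maxBy(row(_)))

  // Fraction of positions where the predicted class equals the true class.
  def accuracy(yTrue: Seq[Int], yPred: Seq[Int]): Double = {
    require(yTrue.length == yPred.length && yTrue.nonEmpty, "label sequences must be non-empty and equally long")
    yTrue.zip(yPred).count { case (a, b) => a == b }.toDouble / yTrue.length
  }
}
```

For example, `LabelEncoding.accuracy(Seq(1, 0, 2, 2), Seq(1, 0, 1, 2))` is 0.75, since three of the four predictions match.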

Set a different activation function

 import com.sircamp.elm.ExtremeLearningMachine
 // ActivationFunctions must also be in scope; its import is not shown in this example.

 val featuresLength = 28 * 28 // MNIST: 28x28 pixel images
 val hiddenLayerDimension = 1024
 val elm = new ExtremeLearningMachine(featuresLength, hiddenLayerDimension)

 /**
  Initialize the hidden-layer weights from a random uniform distribution.
  Alternatively, you can set the weights yourself: they must be a
  DenseMatrix[Double] whose number of rows equals featuresLength.
 **/
 elm.initializeWeights()

 /**
  Use Leaky ReLU; the argument is alpha, the slope applied to negative inputs.
 **/
 elm.setActivationFunction(ActivationFunctions.leakyReLu(0.2))

 /**
  Or use tanh instead (setting an activation replaces the previous one).
 **/
 elm.setActivationFunction(ActivationFunctions.tanh)
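As a point of reference, the two activations above can be written as plain `Double => Double` functions. The sketch below is hypothetical: `ActivationSketch` and `applyAll` are illustrative names, and the library's `ActivationFunctions` may represent activations differently (for example, operating on whole Breeze matrices). Leaky ReLU returns x for x >= 0 and alpha * x otherwise, so `leakyReLu(0.2)` maps -1.0 to -0.2.

```scala
// Hypothetical element-wise activations (not the library's ActivationFunctions).
object ActivationSketch {
  // Leaky ReLU: identity for x >= 0, slope alpha for x < 0.
  def leakyReLu(alpha: Double): Double => Double =
    x => if (x >= 0) x else alpha * x

  // Hyperbolic tangent, squashing values into (-1, 1).
  val tanh: Double => Double = x => math.tanh(x)

  // Applies an activation to every entry of a row-major matrix (e.g. X W + b).
  def applyAll(m: Array[Array[Double]], f: Double => Double): Array[Array[Double]] =
    m.map(_.map(f))
}
```

Because an activation here is just a function value, swapping it is a matter of passing a different function, which matches the `setActivationFunction` style used above.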

For more examples, take a look at the tests.