From d3bc7474996322030bc83ea76e1c377d52f76199 Mon Sep 17 00:00:00 2001 From: Han Qi Date: Sun, 23 Apr 2017 13:32:39 -0700 Subject: [PATCH] Add LearnerObserver. This enables usage of BIDViz --- src/main/scala/BIDMach/Learner.scala | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/main/scala/BIDMach/Learner.scala b/src/main/scala/BIDMach/Learner.scala index 5648164d..93a85ac3 100755 --- a/src/main/scala/BIDMach/Learner.scala +++ b/src/main/scala/BIDMach/Learner.scala @@ -135,6 +135,9 @@ case class Learner( if (opts.updateAll) { model.dobatchg(mats, ipass, here); if (mixins != null) mixins map (_ compute(mats, here)); + if (opts.observer != null) { + opts.observer.notify(ipass, model, mats) + } if (updater != null) updater.update(ipass, here, gprogress); } val scores = model.evalbatchg(mats, ipass, here); @@ -144,6 +147,9 @@ case class Learner( } else { model.dobatchg(mats, ipass, here) if (mixins != null) mixins map (_ compute(mats, here)) + if (opts.observer != null) { + opts.observer.notify(ipass, model, mats) + } if (updater != null) updater.update(ipass, here, gprogress) } if (datasource.opts.putBack >= 0) datasource.putBack(mats, datasource.opts.putBack) @@ -814,6 +820,11 @@ class ParLearnerF( } object Learner { + trait LearnerObserver { + def init = {} + def cleanup = {} + def notify(ipass:Int, model:Model, minibatch:Array[Mat]) = {} + } class Options extends BIDMat.Opts { var npasses = 2; @@ -827,6 +838,7 @@ object Learner { var cumScore = 0; var checkPointFile:String = null; var checkPointInterval = 0f; + var observer: LearnerObserver = null; } def numBytes(mat:Mat):Long = {