Skip to content

Commit

Permalink
Merge pull request #488 from JohnSnowLabs/context-spell-tf-access
Browse files: browse the repository at this point in the history
Fixed concurrent access to TF in spell checker
  • Loading branch information
saif-ellafi authored Apr 29, 2019
2 parents 0f84d5e + 68ce5a7 commit 0977e10
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 24 deletions.
27 changes: 12 additions & 15 deletions src/main/scala/com/johnsnowlabs/ml/tensorflow/TensorflowSpell.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,32 +18,29 @@ class TensorflowSpell(
val lossKey = "Add:0"
val dropoutRate = "dropout_rate"

/* returns the loss associated with the last word, given previous history */
def predict(dataset: Array[Array[Int]], cids: Array[Array[Int]], cwids:Array[Array[Int]]) = this.synchronized {
// these are the inputs to the graph
val wordIds = "batches:0"
val contextIds = "batches:1"
val contextWordIds = "batches:2"

val packed = dataset.zip(cids).zip(cwids).map {
case ((_ids, _cids), _cwids) => Array(_ids, _cids, _cwids)
}
/* returns the loss associated with the last word, given previous history */
def predict(dataset: Array[Array[Int]], cids: Array[Array[Int]], cwids:Array[Array[Int]]) = {

val tensors = new TensorResources()
val inputTensor = tensors.createTensor(packed)

tensorflow.session.runner
.feed(inMemoryInput, inputTensor)
.addTarget(testInitOp)
.run()
val tensors = new TensorResources

val lossWords = tensorflow.session.runner
.feed(dropoutRate, tensors.createTensor(1.0f))
.feed(wordIds, tensors.createTensor(dataset.map(_.dropRight(1))))
.feed(contextIds, tensors.createTensor(cids.map(_.tail)))
.feed(contextWordIds, tensors.createTensor(cwids.map(_.tail)))
.fetch(lossKey)
.fetch(validWords)
.run()

tensors.clearTensors()

val result = extractFloats(lossWords.get(0))
val width = inputTensor.shape()(2)
result.grouped(width.toInt - 1).map(_.last)

val width = dataset.head.length
result.grouped(width - 1).map(_.last)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import org.apache.hadoop.fs.FileSystem
import org.apache.spark.ml.Pipeline
import org.apache.spark.sql.SparkSession
import org.scalatest._
import SparkAccessor.spark
import spark.implicits._


class ContextSpellCheckerTestSpec extends FlatSpec {
Expand Down Expand Up @@ -99,9 +101,6 @@ class ContextSpellCheckerTestSpec extends FlatSpec {


"a Spell Checker" should "work in a pipeline with Tokenizer" in {
import SparkAccessor.spark
import spark.implicits._

val data = Seq("It was a cold , dreary day and the country was white with smow .",
"He wos re1uctant to clange .",
"he is gane .").toDF("text")
Expand Down Expand Up @@ -130,8 +129,6 @@ class ContextSpellCheckerTestSpec extends FlatSpec {

}



"a Spell Checker" should "work in a light pipeline" in {
import SparkAccessor.spark
import spark.implicits._
Expand All @@ -155,10 +152,7 @@ class ContextSpellCheckerTestSpec extends FlatSpec {

val pipeline = new Pipeline().setStages(Array(documentAssembler, tokenizer, spellChecker)).fit(Seq.empty[String].toDF("text"))
val lp = new LightPipeline(pipeline)
lp.annotate(data)
lp.annotate(data)
lp.annotate(data)

lp.annotate(data ++ data ++ data)
}


Expand Down

0 comments on commit 0977e10

Please sign in to comment.