Merge pull request #78 from JohnSnowLabs/serialization_features_bc
Serialization features bc
saif-ellafi authored Jan 8, 2018
2 parents e15ae81 + 8208245 commit 3fd55f8
Showing 14 changed files with 354 additions and 181 deletions.
70 changes: 38 additions & 32 deletions python/example/vivekn-sentiment/sentiment.ipynb
@@ -9,32 +9,15 @@
"outputs": [],
"source": [
"#Imports\n",
"import time\n",
"import sys\n",
"sys.path.append('../../')\n",
"\n",
"from pyspark.ml import Pipeline\n",
"from pyspark.ml import Pipeline, PipelineModel\n",
"from sparknlp.annotator import *\n",
"from sparknlp.base import DocumentAssembler, Finisher\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from pyspark.sql import SparkSession\n",
"\n",
"spark = SparkSession.builder \\\n",
" .master(\"local[2]\") \\\n",
" .config(\"spark.jar\", \"lib/sparknlp.jar\") \\\n",
" .config(\"spark.driver.memory\", \"5g\")\\\n",
" .config(\"spark.dirver.maxResultSize\", \"2g\")\\\n",
" .getOrCreate()"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -137,9 +120,22 @@
"sentiment_detector = ViveknSentimentApproach() \\\n",
" .setInputCols([\"spell\", \"sentence\"]) \\\n",
" .setOutputCol(\"sentiment\") \\\n",
" .setPruneCorpus(False) \\\n",
" .setPositiveSource(\"../../../src/test/resources/vivekn/positive\") \\\n",
" .setNegativeSource(\"../../../src/test/resources/vivekn/negative\") \\\n",
" .setPruneCorpus(False)\n"
" .setNegativeSource(\"../../../src/test/resources/vivekn/negative\") \\\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"pos = PerceptronApproach() \\\n",
" .setInputCols([\"sentence\", \"spell\"]) \\\n",
" .setOutputCol(\"pos\")"
]
},
{
@@ -168,11 +164,15 @@
" normalizer,\n",
" spell_checker,\n",
" sentiment_detector,\n",
" pos,\n",
" finisher\n",
"])\n",
"\n",
"start = time.time()\n",
"sentiment_data = pipeline.fit(data).transform(data)\n",
"sentiment_data.show()"
"sentiment_data.show()\n",
"end = time.time()\n",
"print(\"Time elapsed pipeline process: \" + str(end - start))"
]
},
{
@@ -188,24 +188,27 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": [
"start = time.time()\n",
"pipeline.write().overwrite().save(\"./ps\")\n",
"pipeline.fit(data).write().overwrite().save(\"./ms\")"
"pipeline.fit(data).write().overwrite().save(\"./ms\")\n",
"end = time.time()\n",
"print(\"Time elapsed in write pipelines: \" + str(end - start))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": [
"from pyspark.ml import Pipeline,PipelineModel"
"start = time.time()\n",
"p = Pipeline.read().load(\"./ps\")\n",
"pm = PipelineModel.read().load(\"./ms\")\n",
"end = time.time()\n",
"print(\"Time elapsed in read pipelines: \" + str(end - start))"
]
},
{
@@ -214,8 +217,11 @@
"metadata": {},
"outputs": [],
"source": [
"Pipeline.read().load(\"./ps\")\n",
"PipelineModel.read().load(\"./ms\")"
"start = time.time()\n",
"pm.transform(data).where(\"finished_sentiment not like '%negative%'\").show()\n",
"print(pm.transform(data).count())\n",
"end = time.time()\n",
"print(\"Time elapsed in using loaded pipelines: \" + str(end - start))"
]
},
{
3 changes: 1 addition & 2 deletions src/main/scala/com/johnsnowlabs/nlp/AnnotatorModel.scala
@@ -2,7 +2,6 @@ package com.johnsnowlabs.nlp

import org.apache.spark.ml.Model
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.DefaultParamsWritable
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.types._
@@ -15,7 +14,7 @@ import org.apache.spark.sql.functions.{array, udf}
*/
abstract class AnnotatorModel[M <: Model[M]]
extends Model[M]
with DefaultParamsWritable
with ParamsAndFeaturesWritable
with HasAnnotatorType
with HasInputAnnotationCols
with HasOutputAnnotationCol {
35 changes: 35 additions & 0 deletions src/main/scala/com/johnsnowlabs/nlp/HasFeatures.scala
@@ -0,0 +1,35 @@
package com.johnsnowlabs.nlp

import com.johnsnowlabs.nlp.serialization.{ArrayFeature, Feature, MapFeature, StructFeature}

import scala.collection.mutable.ArrayBuffer

trait HasFeatures {

val features: ArrayBuffer[Feature[_, _, _]] = ArrayBuffer.empty

protected def set[T](feature: ArrayFeature[T], value: Array[T]): this.type = {feature.setValue(Some(value)); this}

protected def set[K, V](feature: MapFeature[K, V], value: Map[K, V]): this.type = {feature.setValue(Some(value)); this}

protected def set[T](feature: StructFeature[T], value: T): this.type = {feature.setValue(Some(value)); this}

protected def setDefault[T](feature: ArrayFeature[T], value: Array[T]): this.type = {feature.setValue(Some(value)); this}

protected def setDefault[K, V](feature: MapFeature[K, V], value: Map[K, V]): this.type = {feature.setValue(Some(value)); this}

protected def setDefault[T](feature: StructFeature[T], value: T): this.type = {feature.setValue(Some(value)); this}

protected def get[T](feature: ArrayFeature[T]): Option[Array[T]] = feature.get

protected def get[K, V](feature: MapFeature[K, V]): Option[Map[K, V]] = feature.get

protected def get[T](feature: StructFeature[T]): Option[T] = feature.get

protected def $$[T](feature: ArrayFeature[T]): Array[T] = feature.getValue

protected def $$[K, V](feature: MapFeature[K, V]): Map[K, V] = feature.getValue

protected def $$[T](feature: StructFeature[T]): T = feature.getValue

}
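
For orientation, here is a minimal sketch of how an annotator-side class might use these helpers. DictionaryHolder, dict and lookup are hypothetical names, and the MapFeature constructor call assumes the (owner, name) shape used by the Lemmatizer change further down.

import com.johnsnowlabs.nlp.HasFeatures
import com.johnsnowlabs.nlp.serialization.MapFeature

// Hypothetical holder; real annotators pick up HasFeatures through
// AnnotatorModel and ParamsAndFeaturesWritable rather than mixing it in directly.
class DictionaryHolder extends HasFeatures {

  // Declared like lemmaDict in the Lemmatizer change below: owner plus feature name
  val dict: MapFeature[String, String] = new MapFeature(this, "dict")

  def setDict(value: Map[String, String]): this.type = set(dict, value)

  // get returns an Option; $$ unwraps it and expects the feature to have been set
  def getDict: Option[Map[String, String]] = get(dict)

  def lookup(token: String): String = $$(dict).getOrElse(token, token)
}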
@@ -0,0 +1,31 @@
package com.johnsnowlabs.nlp

import org.apache.spark.ml.util.{DefaultParamsReadable, MLReader}
import org.apache.spark.sql.SparkSession

class FeaturesReader[T <: HasFeatures](baseReader: MLReader[T], onRead: (T, String, SparkSession) => Unit) extends MLReader[T] {

override def load(path: String): T = {

val instance = baseReader.load(path)

for (feature <- instance.features) {
val value = feature.deserialize(sparkSession, path, feature.name)
feature.setValue(value)
}

onRead(instance, path, sparkSession)

instance
}
}

trait ParamsAndFeaturesReadable[T <: HasFeatures] extends DefaultParamsReadable[T] {

def onRead(instance: T, path: String, spark: SparkSession): Unit = {}

override def read: MLReader[T] = new FeaturesReader(
super.read,
(instance: T, path: String, spark: SparkSession) => onRead(instance, path, spark)
)
}
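
A companion object only needs to mix this trait in to get a feature-aware reader, as the Lemmatizer object does further down in this diff. A hedged usage sketch, with a made-up path:

import com.johnsnowlabs.nlp.annotators.Lemmatizer

// read comes from ParamsAndFeaturesReadable: the FeaturesReader above first restores
// the Params via the base reader, then deserializes each registered Feature (e.g. lemmaDict)
val restored: Lemmatizer = Lemmatizer.read.load("/tmp/lemmatizer_model")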
@@ -0,0 +1,32 @@
package com.johnsnowlabs.nlp

import org.apache.spark.ml.param.Params
import org.apache.spark.ml.util.{DefaultParamsWritable, MLWriter}
import org.apache.spark.sql.SparkSession

class FeaturesWriter[T](annotatorWithFeatures: HasFeatures, baseWriter: MLWriter, onWritten: (String, SparkSession) => Unit)
extends MLWriter with HasFeatures {

override protected def saveImpl(path: String): Unit = {
baseWriter.save(path)

for (feature <- annotatorWithFeatures.features) {
feature.serializeInfer(sparkSession, path, feature.name, feature.getValue)
}

onWritten(path, sparkSession)

}
}

trait ParamsAndFeaturesWritable extends DefaultParamsWritable with Params with HasFeatures {

def onWritten(path: String, spark: SparkSession): Unit = {}

override def write: MLWriter = new FeaturesWriter(
this,
super.write,
(path: String, spark: SparkSession) => onWritten(path, spark)
)

}
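
And the writing side, sketched with the Lemmatizer from later in this diff; the dictionary contents and path are made up for illustration.

import com.johnsnowlabs.nlp.annotators.Lemmatizer

// write returns the FeaturesWriter above: DefaultParamsWritable persists the Params,
// then each registered Feature (here lemmaDict) is serialized alongside them
new Lemmatizer()
  .setLemmaDictMap(Map("feet" -> "foot", "mice" -> "mouse"))
  .write.overwrite().save("/tmp/lemmatizer_model")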
29 changes: 16 additions & 13 deletions src/main/scala/com/johnsnowlabs/nlp/annotators/Lemmatizer.scala
@@ -1,12 +1,12 @@
package com.johnsnowlabs.nlp.annotators

import com.johnsnowlabs.nlp.annotators.common.StringMapParam
import com.johnsnowlabs.nlp.serialization.MapFeature
import com.johnsnowlabs.nlp.util.io.ResourceHelper
import com.johnsnowlabs.nlp.{Annotation, AnnotatorModel}
import com.johnsnowlabs.nlp.{Annotation, AnnotatorModel, ParamsAndFeaturesReadable}
import com.typesafe.config.Config
import com.johnsnowlabs.nlp.util.ConfigHelper
import org.apache.spark.ml.param.Param
import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable}
import org.apache.spark.ml.util.Identifiable

import scala.collection.JavaConverters._

@@ -25,7 +25,7 @@ class Lemmatizer(override val uid: String) extends AnnotatorModel[Lemmatizer] {

private val config: Config = ConfigHelper.retrieve

val lemmaDict: StringMapParam = new StringMapParam(this, "lemmaDict", "provide a lemma dictionary")
val lemmaDict: MapFeature[String, String] = new MapFeature(this, "lemmaDict")

val lemmaFormat: Param[String] = new Param[String](this, "lemmaFormat", "TXT or TXTDS for reading dictionary as dataset")

@@ -52,15 +52,23 @@ class Lemmatizer(override val uid: String) extends AnnotatorModel[Lemmatizer] {

def this() = this(Identifiable.randomUID("LEMMATIZER"))

def getLemmaDict: Map[String, String] = $(lemmaDict)
def getLemmaDict: Map[String, String] = $$(lemmaDict)
protected def getLemmaFormat: String = $(lemmaFormat)
protected def getLemmaKeySep: String = $(lemmaKeySep)
protected def getLemmaValSep: String = $(lemmaValSep)

def setLemmaDict(dictionary: String): this.type = {
set(lemmaDict, Lemmatizer.retrieveLemmaDict(dictionary, $(lemmaFormat), $(lemmaKeySep), $(lemmaValSep)))
}

def setLemmaDictHMap(dictionary: java.util.HashMap[String, String]): this.type = {
set(lemmaDict, dictionary.asScala.toMap)
}
def setLemmaDictMap(dictionary: Map[String, String]): this.type = {
set(lemmaDict, dictionary)
}
def setLemmaFormat(value: String): this.type = set(lemmaFormat, value)
def setLemmaKeySep(value: String): this.type = set(lemmaKeySep, value)
def setLemmaValSep(value: String): this.type = set(lemmaValSep, value)

/**
* @return one-to-one annotation from token to lemmatized word if found in the dictionary, otherwise the word is left as is
@@ -72,19 +80,14 @@ class Lemmatizer(override val uid: String) extends AnnotatorModel[Lemmatizer] {
annotatorType,
tokenAnnotation.begin,
tokenAnnotation.end,
$(lemmaDict).getOrElse(token, token),
$$(lemmaDict).getOrElse(token, token),
tokenAnnotation.metadata
)
}
}
}

object Lemmatizer extends DefaultParamsReadable[Lemmatizer] {

/**
* Retrieves Lemma dictionary from configured compiled source set in configuration
* @return a Dictionary for lemmas
*/
object Lemmatizer extends ParamsAndFeaturesReadable[Lemmatizer] {
protected def retrieveLemmaDict(
lemmaFilePath: String,
lemmaFormat: String,
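
To round out the Lemmatizer changes, a hedged sketch of the new in-memory setter next to the usual column setters; the dictionary entries are illustrative, and setInputCols/setOutputCol are assumed from the existing annotator API.

import com.johnsnowlabs.nlp.annotators.Lemmatizer

// setLemmaDictMap takes the dictionary directly, bypassing the path-based setLemmaDict;
// setLemmaDictHMap offers the same for java.util.HashMap callers (e.g. from Python)
val lemmatizer = new Lemmatizer()
  .setInputCols(Array("token"))
  .setOutputCol("lemma")
  .setLemmaDictMap(Map("children" -> "child", "mice" -> "mouse"))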