Skip to content

Commit

Permalink
Merge pull request #70 from JohnSnowLabs/word_embeddings
Browse files Browse the repository at this point in the history
Word embeddings
  • Loading branch information
saif-ellafi authored Jan 8, 2018
2 parents fc147a0 + 6f30421 commit 680d5ef
Show file tree
Hide file tree
Showing 28 changed files with 608 additions and 156 deletions.
4 changes: 3 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ lazy val testDependencies = Seq(

lazy val utilDependencies = Seq(
"com.typesafe" % "config" % "1.3.0",
"org.fusesource.leveldbjni" % "leveldbjni-all" % "1.8"
"org.rocksdb" % "rocksdbjni" % "5.8.0",
"org.slf4j" % "slf4j-api" % "1.7.25",
"org.apache.commons" % "commons-compress" % "1.15"
)

lazy val root = (project in file("."))
Expand Down
22 changes: 21 additions & 1 deletion python/sparknlp/annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,26 @@ def setOutputCol(self, value):
return self._set(outputCol=value)


class AnnotatorWithEmbeddings(Params):
    """Mixin declaring the params that describe a word embeddings source.

    Declares three params: the path of the embeddings file, the file
    format (stored as an int) and the number of dimensions of the word
    vectors. All three are assigned together via ``setEmbeddingsSource``.
    """

    # Location of the word embeddings file.
    sourceEmbeddingsPath = Param(
        Params._dummy(),
        "sourceEmbeddingsPath",
        "Word embeddings file",
        typeConverter=TypeConverters.toString,
    )

    # Format of the word vectors file, converted to an int.
    embeddingsFormat = Param(
        Params._dummy(),
        "embeddingsFormat",
        "Word vectors file format",
        typeConverter=TypeConverters.toInt,
    )

    # Dimensionality of the word vectors.
    embeddingsNDims = Param(
        Params._dummy(),
        "embeddingsNDims",
        "Number of dimensions for word vectors",
        typeConverter=TypeConverters.toInt,
    )

    def setEmbeddingsSource(self, path, nDims, format):
        """Set the embeddings path, format and dimensionality in one call.

        :param path: value for ``sourceEmbeddingsPath``
        :param nDims: value for ``embeddingsNDims``
        :param format: value for ``embeddingsFormat``
        :return: the result of the final ``_set`` call
        """
        # Params are assigned in the same order as before: path, format, nDims.
        # Chaining relies on ``_set`` returning this instance.
        return (
            self._set(sourceEmbeddingsPath=path)
                ._set(embeddingsFormat=format)
                ._set(embeddingsNDims=nDims)
        )


class AnnotatorTransformer(JavaModel, JavaMLReadable, JavaMLWritable, AnnotatorProperties):

column_type = "array<struct<annotatorType:string,begin:int,end:int,metadata:map<string,string>>>"
Expand Down Expand Up @@ -478,7 +498,7 @@ class NorvigSweetingModel(JavaModel, JavaMLWritable, JavaMLReadable, AnnotatorPr



class NerCrfApproach(JavaEstimator, JavaMLWritable, JavaMLReadable, AnnotatorProperties):
class NerCrfApproach(JavaEstimator, JavaMLWritable, JavaMLReadable, AnnotatorProperties, AnnotatorWithEmbeddings):
labelColumn = Param(Params._dummy(),
"labelColumn",
"Column with label per each token",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,4 +128,4 @@ class DatasetEncoder(val startLabel: String = "@#Start") {
result
}
}
}
}
2 changes: 0 additions & 2 deletions src/main/scala/com/johnsnowlabs/ml/crf/DatasetReader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -132,5 +132,3 @@ object DatasetReader {
}
}
}


12 changes: 9 additions & 3 deletions src/main/scala/com/johnsnowlabs/nlp/AnnotatorApproach.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package com.johnsnowlabs.nlp

import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.types.{ArrayType, MetadataBuilder, StructField, StructType}
import org.apache.spark.ml.util.DefaultParamsWritable

Expand All @@ -24,8 +24,15 @@ abstract class AnnotatorApproach[M <: Model[M]]

def train(dataset: Dataset[_]): M

def beforeTraining(spark: SparkSession): Unit = {}

def onTrained(model: M, spark: SparkSession): Unit = {}

override final def fit(dataset: Dataset[_]): M = {
copyValues(train(dataset).setParent(this))
beforeTraining(dataset.sparkSession)
val model = copyValues(train(dataset).setParent(this))
onTrained(model, dataset.sparkSession)
model
}

override final def copy(extra: ParamMap): Estimator[M] = defaultCopy(extra)
Expand All @@ -50,5 +57,4 @@ abstract class AnnotatorApproach[M <: Model[M]]
StructField(getOutputCol, ArrayType(Annotation.dataType), nullable = false, metadataBuilder.build)
StructType(outputFields)
}

}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ package com.johnsnowlabs.nlp.annotators.ner.crf

import com.johnsnowlabs.nlp.util.io.ResourceHelper

import scala.io.Source

case class DictionaryFeatures(dict: Map[String, String])
{
def get(tokens: Seq[String]): Seq[String] = {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
package com.johnsnowlabs.nlp.annotators.ner.crf

import com.johnsnowlabs.ml.crf._
import com.johnsnowlabs.nlp.annotators.common.{TaggedSentence, WordEmbeddings}
import com.johnsnowlabs.nlp.annotators.common.TaggedSentence
import com.johnsnowlabs.nlp.embeddings.WordEmbeddings

import scala.collection.mutable

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import com.johnsnowlabs.nlp.annotators.common.NerTagged
import com.johnsnowlabs.nlp.annotators.pos.perceptron.PerceptronApproach
import com.johnsnowlabs.nlp.annotators.sbd.pragmatic.SentenceDetectorModel
import com.johnsnowlabs.nlp.datasets.CoNLL
import com.johnsnowlabs.nlp.embeddings.AnnotatorWithWordEmbeddings
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.param.{DoubleParam, IntParam, Param, StringArrayParam}
import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable}
Expand All @@ -17,7 +18,9 @@ import org.apache.spark.sql.{DataFrame, Dataset}
/*
Algorithm for training Named Entity Recognition Model.
*/
class NerCrfApproach(override val uid: String) extends AnnotatorApproach[NerCrfModel]{
class NerCrfApproach(override val uid: String)
extends AnnotatorWithWordEmbeddings[NerCrfApproach, NerCrfModel] {

def this() = this(Identifiable.randomUID("NER"))

override val description = "CRF based Named Entity Recognition Tagger"
Expand Down Expand Up @@ -116,7 +119,8 @@ class NerCrfApproach(override val uid: String) extends AnnotatorApproach[NerCrfM

val dictPaths = get(dicts).getOrElse(Array.empty[String])
val dictFeatures = DictionaryFeatures.read(dictPaths.toSeq)
val crfDataset = FeatureGenerator(dictFeatures).generateDataset(trainDataset)
val crfDataset = FeatureGenerator(dictFeatures, embeddings)
.generateDataset(trainDataset)

val params = CrfParams(
minEpochs = getOrDefault(minEpochs),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import com.johnsnowlabs.ml.crf.{LinearChainCrfModel, SerializedLinearChainCrfMod
import com.johnsnowlabs.nlp.AnnotatorType._
import com.johnsnowlabs.nlp.annotators.common.{IndexedTaggedWord, NerTagged, PosTagged, TaggedSentence}
import com.johnsnowlabs.nlp.annotators.common.Annotated.{NerTaggedSentence, PosTaggedSentence}
import com.johnsnowlabs.nlp.embeddings.ModelWithWordEmbeddings
import com.johnsnowlabs.nlp.{Annotation, AnnotatorModel}
import org.apache.hadoop.fs.Path
import org.apache.spark.ml.param.StringArrayParam
Expand All @@ -14,8 +15,7 @@ import org.apache.spark.sql.{Encoders, Row}
/*
Named Entity Recognition model
*/
class NerCrfModel(override val uid: String)
extends AnnotatorModel[NerCrfModel] {
class NerCrfModel(override val uid: String) extends ModelWithWordEmbeddings[NerCrfModel]{

def this() = this(Identifiable.randomUID("NER"))

Expand All @@ -36,7 +36,7 @@ class NerCrfModel(override val uid: String)
def setEntities(toExtract: Array[String]): NerCrfModel = set(entities, toExtract)

/**
Predicts Named Entities in input sentences
Predicts Named Entities in input sentences
* @param sentences POS tagged sentences.
* @return sentences with recognized Named Entities
*/
Expand All @@ -45,8 +45,9 @@ class NerCrfModel(override val uid: String)

val crf = model.get

val fg = FeatureGenerator(dictionaryFeatures, embeddings)
sentences.map{sentence =>
val instance = FeatureGenerator(dictionaryFeatures).generate(sentence, crf.metadata)
val instance = fg.generate(sentence, crf.metadata)
val labelIds = crf.predict(instance)
val words = sentence.indexedTaggedWords
.zip(labelIds.labels)
Expand Down Expand Up @@ -116,6 +117,9 @@ object NerCrfModel extends DefaultParamsReadable[NerCrfModel] {
instance
.setModel(crfModel.deserialize)
.setDictionaryFeatures(dictFeatures)

instance.deserializeEmbeddings(path, sparkSession.sparkContext)
instance
}
}

Expand All @@ -136,7 +140,8 @@ object NerCrfModel extends DefaultParamsReadable[NerCrfModel] {
val dictPath = new Path(path, "dict").toString
val dictLines = model.dictionaryFeatures.dict.toSeq.map(p => p._1 + ":" + p._2)
Seq(dictLines).toDS.write.mode("overwrite").parquet(dictPath)

model.serializeEmbeddings(path, sparkSession.sparkContext)
}
}
}

4 changes: 2 additions & 2 deletions src/main/scala/com/johnsnowlabs/nlp/datasets/CoNLL.scala
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ case class CoNLL(targetColumn: Int = 3, annotatorType: String) {

def pack(sentences: Seq[TaggedSentence]): Seq[Annotation] = {
if (annotatorType == AnnotatorType.NAMED_ENTITY)
NerTagged.pack(sentences)
NerTagged.pack(sentences)
else
PosTagged.pack(sentences)
}
Expand All @@ -99,4 +99,4 @@ case class CoNLL(targetColumn: Int = 3, annotatorType: String) {
val seq = readLines(lines).map(p => (p._1, pack(p._2)))
seq.toDF(textColumn, labelColumn)
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package com.johnsnowlabs.nlp.datasets

import java.io.File

import com.johnsnowlabs.ml.crf.{CrfDataset, DatasetMetadata, InstanceLabels, TextSentenceLabels}
import com.johnsnowlabs.nlp.AnnotatorType
import com.johnsnowlabs.nlp.annotators.common.TaggedSentence
import com.johnsnowlabs.nlp.annotators.ner.crf.{DictionaryFeatures, FeatureGenerator}
import com.johnsnowlabs.nlp.embeddings.{WordEmbeddings, WordEmbeddingsFormat, WordEmbeddingsIndexer}

/**
  * Helper class for working with the CoNLL 2003 dataset for the NER task.
  * The class is made for easy use from Java, which is why the optional
  * inputs are plain nullable strings.
  *
  * @param wordEmbeddingsFile  path to the word embeddings file, or null to work without embeddings;
  *                            when non-null the file must exist
  * @param wordEmbeddingsNDims number of dimensions of the word vectors
  * @param embeddingsFormat    format of the embeddings file; decides how (and whether) it gets indexed
  * @param dictionaryFile      path to a dictionary file, or null to use no dictionary
  */
class CoNLL2003NerReader(wordEmbeddingsFile: String,
                         wordEmbeddingsNDims: Int,
                         embeddingsFormat: WordEmbeddingsFormat.Format,
                         dictionaryFile: String) {

  // The same file is parsed twice: target column 3 for NER labels, column 1 for POS tags.
  private val nerReader = CoNLL(3, AnnotatorType.NAMED_ENTITY)
  private val posReader = CoNLL(1, AnnotatorType.POS)

  // Embeddings are available only when a source file was given and an index could be located.
  private val wordEmbeddings: Option[WordEmbeddings] =
    Option(wordEmbeddingsFile).flatMap { sourcePath =>
      require(new File(sourcePath).exists())

      val defaultIndex = sourcePath + ".db"

      // Reuse an existing "<file>.db" index; otherwise build it, or use the source itself
      // when the embeddings file is already in SparkNlp format.
      val indexPath =
        if (new File(defaultIndex).exists()) {
          defaultIndex
        } else {
          embeddingsFormat match {
            case WordEmbeddingsFormat.Text =>
              WordEmbeddingsIndexer.indexText(sourcePath, defaultIndex)
              defaultIndex
            case WordEmbeddingsFormat.Binary =>
              WordEmbeddingsIndexer.indexBinary(sourcePath, defaultIndex)
              defaultIndex
            case WordEmbeddingsFormat.SparkNlp =>
              sourcePath
          }
        }

      if (new File(indexPath).exists()) Some(WordEmbeddings(indexPath, wordEmbeddingsNDims))
      else None
    }

  // A null dictionary path means "no dictionaries".
  private val dicts: Seq[String] = Option(dictionaryFile).toList

  private val fg = FeatureGenerator(
    DictionaryFeatures.read(dicts),
    wordEmbeddings
  )

  /**
    * Reads the file once with the NER reader and once with the POS reader and pairs,
    * sentence by sentence, the NER labels with the POS tagged sentence.
    */
  private def readDataset(file: String): Seq[(TextSentenceLabels, TaggedSentence)] = {
    val nerSentences = nerReader.readDocs(file).flatMap(doc => doc._2)
    val labels = nerSentences.map(tagged => TextSentenceLabels(tagged.tags))

    val posTaggedSentences = posReader.readDocs(file).flatMap(doc => doc._2)
    labels.zip(posTaggedSentences)
  }

  /**
    * Builds a CRF dataset from a CoNLL file.
    *
    * @param file     path to the CoNLL file
    * @param metadata when empty, the dataset is generated by the feature generator from the
    *                 file itself; when defined, instances are generated against this metadata
    *                 and labels missing from its label2Id map are encoded as -1
    * @return the resulting CRF dataset
    */
  def readNerDataset(file: String, metadata: Option[DatasetMetadata] = None): CrfDataset = {
    val lines = readDataset(file)

    metadata match {
      case None =>
        fg.generateDataset(lines)

      case Some(meta) =>
        val labeledInstances = lines.map { case (sentenceLabels, taggedSentence) =>
          val instance = fg.generate(taggedSentence, meta)
          val labelIds = sentenceLabels.labels.map(label => meta.label2Id.getOrElse(label, -1))
          (InstanceLabels(labelIds), instance)
        }
        CrfDataset(labeledInstances, meta)
    }
  }
}
Loading

0 comments on commit 680d5ef

Please sign in to comment.