Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SPARKNLP-823 Adding streaming functionality for seq2seq components #13899

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions python/sparknlp/base/light_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from sparknlp.annotation import Annotation
from sparknlp.annotation_audio import AnnotationAudio
from sparknlp.annotation_image import AnnotationImage
from sparknlp.common import AnnotatorApproach, AnnotatorModel
from sparknlp.common import AnnotatorApproach, AnnotatorModel, isSeq2Seq
from sparknlp.internal import AnnotatorTransformer


Expand Down Expand Up @@ -315,7 +315,7 @@ def __buildStages(self, annotations_result):
stages[annotator_type] = self._annotationFromJava(annotations)
return stages

def annotate(self, target, optional_target=""):
def annotate(self, target, optional_target="", streamer=False):
"""Annotates the data provided, extracting the results.

The data should be either a list or a str.
Expand Down Expand Up @@ -347,13 +347,16 @@ def reformat(annotations):
return {k: list(v) for k, v in annotations.items()}

stages = self.pipeline_model.stages

if not self._skipPipelineValidation(stages):
self._validateStagesInputCols(stages)

if optional_target == "":
if type(target) is str:
annotations = self._lightPipeline.annotateJava(target)
result = reformat(annotations)
if streamer:
self.__streamResult(stages, result)
elif type(target) is list:
if type(target[0]) is list:
raise TypeError("target is a 1D list")
Expand Down Expand Up @@ -416,3 +419,13 @@ def getIgnoreUnsupported(self):
Whether to ignore unsupported AnnotatorModels.
"""
return self._lightPipeline.getIgnoreUnsupported()

def __streamResult(self, stages, result):
seq2seq_output_cols = self.__getSeq2SeqOutputCols(stages)
for seq2seq_output_col in seq2seq_output_cols:
if seq2seq_output_col in result:
print(f"{result[seq2seq_output_col][0]}")

def __getSeq2SeqOutputCols(self, stages):
seq2seq_stages = [stage for stage in stages if isSeq2Seq(stage)]
return [stage.getOutputCol() for stage in seq2seq_stages]
5 changes: 5 additions & 0 deletions python/sparknlp/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains utilities for annotators."""
import sys

from sparknlp.common.read_as import ReadAs
import sparknlp.internal as _internal
Expand All @@ -37,3 +38,7 @@ def ExternalResource(path, read_as=ReadAs.TEXT, options={}):
def RegexRule(rule, identifier):
return _internal._RegexRule(rule, identifier).apply()


def isSeq2Seq(instance):
module = sys.modules[instance.__class__.__module__]
return module.__name__.startswith("sparknlp.annotator.seq2seq")
34 changes: 25 additions & 9 deletions src/main/scala/com/johnsnowlabs/nlp/LightPipeline.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.johnsnowlabs.nlp

import com.johnsnowlabs.nlp.annotator.isSeq2SeqTransformer
import com.johnsnowlabs.nlp.annotators.cv.util.io.ImageIOUtils
import com.johnsnowlabs.nlp.util.io.ResourceHelper
import org.apache.spark.ml.{PipelineModel, Transformer}
Expand Down Expand Up @@ -393,16 +394,25 @@ class LightPipeline(val pipelineModel: PipelineModel, parseEmbeddings: Boolean =
.toArray
}

def annotate(target: String, optionalTarget: String = ""): Map[String, Seq[String]] = {
fullAnnotate(target, optionalTarget).mapValues(_.map { iAnnotation =>
val annotation = iAnnotation.asInstanceOf[Annotation]
annotation.annotatorType match {
case AnnotatorType.WORD_EMBEDDINGS | AnnotatorType.SENTENCE_EMBEDDINGS
if parseEmbeddings =>
annotation.embeddings.mkString(" ")
case _ => annotation.result
def annotate(
target: String,
optionalTarget: String = "",
streamer: Boolean = false): Map[String, Seq[String]] = {
fullAnnotate(target, optionalTarget).map { case (outputCol, annotation) =>
outputCol -> annotation.map { iAnnotation =>
val annotation = iAnnotation.asInstanceOf[Annotation]
val output = annotation.annotatorType match {
case AnnotatorType.WORD_EMBEDDINGS | AnnotatorType.SENTENCE_EMBEDDINGS
if parseEmbeddings =>
annotation.embeddings.mkString(" ")
case _ => annotation.result
}
if (streamer && getSeq2SeqOutputCols.contains(outputCol)) {
println(s"$output")
}
output
}
})
}
}

def annotate(
Expand Down Expand Up @@ -449,4 +459,10 @@ class LightPipeline(val pipelineModel: PipelineModel, parseEmbeddings: Boolean =
.asJava
}

private def getSeq2SeqOutputCols: Array[String] = {
getStages
.filter(stage => isSeq2SeqTransformer(stage))
.map(stage => stage.asInstanceOf[HasOutputAnnotationCol].getOutputCol)
}

}
5 changes: 5 additions & 0 deletions src/main/scala/com/johnsnowlabs/nlp/annotator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -705,4 +705,9 @@ package object annotator {
object RoBertaForZeroShotClassification
extends ReadablePretrainedRoBertaForZeroShotModel
with ReadRoBertaForZeroShotDLModel

def isSeq2SeqTransformer(instance: Any): Boolean = {
instance.getClass.getPackage.getName == "com.johnsnowlabs.nlp.annotators.seq2seq"
}

}
Loading