diff --git a/.gitignore b/.gitignore index aa264460d82d79..e91a8952f7c8d5 100644 --- a/.gitignore +++ b/.gitignore @@ -338,3 +338,4 @@ python/docs/reference/_autosummary/** # MS Visio Code **/.vscode/ +.metals/ \ No newline at end of file diff --git a/examples/python/annotation/text/english/text-similarity/doc-sim-ranker/test_doc_sim_ranker.ipynb b/examples/python/annotation/text/english/text-similarity/doc-sim-ranker/test_doc_sim_ranker.ipynb new file mode 100644 index 00000000000000..eb77b388e42dc7 --- /dev/null +++ b/examples/python/annotation/text/english/text-similarity/doc-sim-ranker/test_doc_sim_ranker.ipynb @@ -0,0 +1,542 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "c3dc7ce5", + "metadata": {}, + "source": [ + "# Document Similarity Ranker for Spark NLP\n", + "### Efficient approximate nearest neighbor search on top of sentence embeddings" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "1a9dd32e", + "metadata": {}, + "outputs": [], + "source": [ + "# Import Spark NLP classes\n", + "from sparknlp.base import *\n", + "from sparknlp.annotator import *\n", + "from sparknlp.pretrained import PretrainedPipeline" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "82846deb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + ":: loading settings :: url = jar:file:/Users/stefanolori/opt/anaconda3/envs/spknlp/lib/python3.8/site-packages/pyspark/jars/ivy-2.5.0.jar!/org/apache/ivy/core/settings/ivysettings.xml\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Ivy Default Cache set to: /Users/stefanolori/.ivy2/cache\n", + "The jars for the packages stored in: /Users/stefanolori/.ivy2/jars\n", + "com.johnsnowlabs.nlp#spark-nlp_2.12 added as a dependency\n", + ":: resolving dependencies :: org.apache.spark#spark-submit-parent-d858c4fe-292f-4adf-8944-9ebef53c59cd;1.0\n", + "\tconfs: [default]\n", + "\tfound com.johnsnowlabs.nlp#spark-nlp_2.12;4.4.4 in local-ivy-cache\n", + "\tfound com.typesafe#config;1.4.2 in local-m2-cache\n", + "\tfound org.rocksdb#rocksdbjni;6.29.5 in central\n", + "\tfound com.amazonaws#aws-java-sdk-bundle;1.11.828 in central\n", + "\tfound com.github.universal-automata#liblevenshtein;3.0.0 in central\n", + "\tfound com.google.protobuf#protobuf-java-util;3.0.0-beta-3 in central\n", + "\tfound com.google.protobuf#protobuf-java;3.0.0-beta-3 in central\n", + "\tfound com.google.code.gson#gson;2.3 in central\n", + "\tfound it.unimi.dsi#fastutil;7.0.12 in central\n", + "\tfound org.projectlombok#lombok;1.16.8 in central\n", + "\tfound com.google.cloud#google-cloud-storage;2.16.0 in central\n", + "\tfound com.google.guava#guava;31.1-jre in central\n", + "\tfound com.google.guava#failureaccess;1.0.1 in central\n", + "\tfound com.google.guava#listenablefuture;9999.0-empty-to-avoid-conflict-with-guava in central\n", + "\tfound com.google.errorprone#error_prone_annotations;2.16 in central\n", + "\tfound com.google.j2objc#j2objc-annotations;1.3 in central\n", + "\tfound com.google.http-client#google-http-client;1.42.3 in central\n", + "\tfound io.opencensus#opencensus-contrib-http-util;0.31.1 in central\n", + "\tfound com.google.http-client#google-http-client-jackson2;1.42.3 in central\n", + "\tfound com.google.http-client#google-http-client-gson;1.42.3 in central\n", + "\tfound com.google.api-client#google-api-client;2.1.1 in central\n", + "\tfound commons-codec#commons-codec;1.15 in central\n", + "\tfound 
com.google.oauth-client#google-oauth-client;1.34.1 in central\n", + "\tfound com.google.http-client#google-http-client-apache-v2;1.42.3 in central\n", + "\tfound com.google.apis#google-api-services-storage;v1-rev20220705-2.0.0 in central\n", + "\tfound com.google.code.gson#gson;2.10 in central\n", + "\tfound com.google.cloud#google-cloud-core;2.9.0 in central\n", + "\tfound com.google.auto.value#auto-value-annotations;1.10.1 in central\n", + "\tfound com.google.cloud#google-cloud-core-http;2.9.0 in central\n", + "\tfound com.google.http-client#google-http-client-appengine;1.42.3 in central\n", + "\tfound com.google.api#gax-httpjson;0.105.1 in central\n", + "\tfound com.google.cloud#google-cloud-core-grpc;2.9.0 in central\n", + "\tfound io.grpc#grpc-core;1.51.0 in central\n", + "\tfound com.google.api#gax;2.20.1 in central\n", + "\tfound com.google.api#gax-grpc;2.20.1 in central\n", + "\tfound io.grpc#grpc-alts;1.51.0 in central\n", + "\tfound io.grpc#grpc-grpclb;1.51.0 in central\n", + "\tfound org.conscrypt#conscrypt-openjdk-uber;2.5.2 in central\n", + "\tfound io.grpc#grpc-protobuf;1.51.0 in central\n", + "\tfound com.google.auth#google-auth-library-credentials;1.13.0 in central\n", + "\tfound com.google.auth#google-auth-library-oauth2-http;1.13.0 in central\n", + "\tfound com.google.api#api-common;2.2.2 in central\n", + "\tfound javax.annotation#javax.annotation-api;1.3.2 in local-m2-cache\n", + "\tfound io.opencensus#opencensus-api;0.31.1 in central\n", + "\tfound io.grpc#grpc-context;1.51.0 in central\n", + "\tfound com.google.api.grpc#proto-google-iam-v1;1.6.22 in central\n", + "\tfound com.google.protobuf#protobuf-java;3.21.10 in central\n", + "\tfound com.google.protobuf#protobuf-java-util;3.21.10 in central\n", + "\tfound com.google.api.grpc#proto-google-common-protos;2.11.0 in central\n", + "\tfound org.threeten#threetenbp;1.6.4 in central\n", + "\tfound com.google.api.grpc#proto-google-cloud-storage-v2;2.16.0-alpha in central\n", + "\tfound com.google.api.grpc#grpc-google-cloud-storage-v2;2.16.0-alpha in central\n", + "\tfound com.google.api.grpc#gapic-google-cloud-storage-v2;2.16.0-alpha in central\n", + "\tfound com.fasterxml.jackson.core#jackson-core;2.14.1 in central\n", + "\tfound com.google.code.findbugs#jsr305;3.0.2 in central\n", + "\tfound io.grpc#grpc-api;1.51.0 in central\n", + "\tfound io.grpc#grpc-auth;1.51.0 in central\n", + "\tfound io.grpc#grpc-stub;1.51.0 in central\n", + "\tfound org.checkerframework#checker-qual;3.28.0 in central\n", + "\tfound com.google.api.grpc#grpc-google-iam-v1;1.6.22 in central\n", + "\tfound io.grpc#grpc-protobuf-lite;1.51.0 in central\n", + "\tfound com.google.android#annotations;4.1.1.4 in central\n", + "\tfound org.codehaus.mojo#animal-sniffer-annotations;1.22 in central\n", + "\tfound io.grpc#grpc-netty-shaded;1.51.0 in central\n", + "\tfound io.perfmark#perfmark-api;0.26.0 in central\n", + "\tfound io.grpc#grpc-googleapis;1.51.0 in central\n", + "\tfound io.grpc#grpc-xds;1.51.0 in central\n", + "\tfound io.opencensus#opencensus-proto;0.2.0 in central\n", + "\tfound io.grpc#grpc-services;1.51.0 in central\n", + "\tfound com.google.re2j#re2j;1.6 in central\n", + "\tfound com.navigamez#greex;1.0 in central\n", + "\tfound dk.brics.automaton#automaton;1.11-8 in central\n", + "\tfound com.johnsnowlabs.nlp#tensorflow-cpu_2.12;0.4.4 in central\n", + ":: resolution report :: resolve 1092ms :: artifacts dl 43ms\n", + "\t:: modules in use:\n", + "\tcom.amazonaws#aws-java-sdk-bundle;1.11.828 from central in [default]\n", + 
"\tcom.fasterxml.jackson.core#jackson-core;2.14.1 from central in [default]\n", + "\tcom.github.universal-automata#liblevenshtein;3.0.0 from central in [default]\n", + "\tcom.google.android#annotations;4.1.1.4 from central in [default]\n", + "\tcom.google.api#api-common;2.2.2 from central in [default]\n", + "\tcom.google.api#gax;2.20.1 from central in [default]\n", + "\tcom.google.api#gax-grpc;2.20.1 from central in [default]\n", + "\tcom.google.api#gax-httpjson;0.105.1 from central in [default]\n", + "\tcom.google.api-client#google-api-client;2.1.1 from central in [default]\n", + "\tcom.google.api.grpc#gapic-google-cloud-storage-v2;2.16.0-alpha from central in [default]\n", + "\tcom.google.api.grpc#grpc-google-cloud-storage-v2;2.16.0-alpha from central in [default]\n", + "\tcom.google.api.grpc#grpc-google-iam-v1;1.6.22 from central in [default]\n", + "\tcom.google.api.grpc#proto-google-cloud-storage-v2;2.16.0-alpha from central in [default]\n", + "\tcom.google.api.grpc#proto-google-common-protos;2.11.0 from central in [default]\n", + "\tcom.google.api.grpc#proto-google-iam-v1;1.6.22 from central in [default]\n", + "\tcom.google.apis#google-api-services-storage;v1-rev20220705-2.0.0 from central in [default]\n", + "\tcom.google.auth#google-auth-library-credentials;1.13.0 from central in [default]\n", + "\tcom.google.auth#google-auth-library-oauth2-http;1.13.0 from central in [default]\n", + "\tcom.google.auto.value#auto-value-annotations;1.10.1 from central in [default]\n", + "\tcom.google.cloud#google-cloud-core;2.9.0 from central in [default]\n", + "\tcom.google.cloud#google-cloud-core-grpc;2.9.0 from central in [default]\n", + "\tcom.google.cloud#google-cloud-core-http;2.9.0 from central in [default]\n", + "\tcom.google.cloud#google-cloud-storage;2.16.0 from central in [default]\n", + "\tcom.google.code.findbugs#jsr305;3.0.2 from central in [default]\n", + "\tcom.google.code.gson#gson;2.10 from central in [default]\n", + "\tcom.google.errorprone#error_prone_annotations;2.16 from central in [default]\n", + "\tcom.google.guava#failureaccess;1.0.1 from central in [default]\n", + "\tcom.google.guava#guava;31.1-jre from central in [default]\n", + "\tcom.google.guava#listenablefuture;9999.0-empty-to-avoid-conflict-with-guava from central in [default]\n", + "\tcom.google.http-client#google-http-client;1.42.3 from central in [default]\n", + "\tcom.google.http-client#google-http-client-apache-v2;1.42.3 from central in [default]\n", + "\tcom.google.http-client#google-http-client-appengine;1.42.3 from central in [default]\n", + "\tcom.google.http-client#google-http-client-gson;1.42.3 from central in [default]\n", + "\tcom.google.http-client#google-http-client-jackson2;1.42.3 from central in [default]\n", + "\tcom.google.j2objc#j2objc-annotations;1.3 from central in [default]\n", + "\tcom.google.oauth-client#google-oauth-client;1.34.1 from central in [default]\n", + "\tcom.google.protobuf#protobuf-java;3.21.10 from central in [default]\n", + "\tcom.google.protobuf#protobuf-java-util;3.21.10 from central in [default]\n", + "\tcom.google.re2j#re2j;1.6 from central in [default]\n", + "\tcom.johnsnowlabs.nlp#spark-nlp_2.12;4.4.4 from local-ivy-cache in [default]\n", + "\tcom.johnsnowlabs.nlp#tensorflow-cpu_2.12;0.4.4 from central in [default]\n", + "\tcom.navigamez#greex;1.0 from central in [default]\n", + "\tcom.typesafe#config;1.4.2 from local-m2-cache in [default]\n", + "\tcommons-codec#commons-codec;1.15 from central in [default]\n", + "\tdk.brics.automaton#automaton;1.11-8 from central in 
[default]\n", + "\tio.grpc#grpc-alts;1.51.0 from central in [default]\n", + "\tio.grpc#grpc-api;1.51.0 from central in [default]\n", + "\tio.grpc#grpc-auth;1.51.0 from central in [default]\n", + "\tio.grpc#grpc-context;1.51.0 from central in [default]\n", + "\tio.grpc#grpc-core;1.51.0 from central in [default]\n", + "\tio.grpc#grpc-googleapis;1.51.0 from central in [default]\n", + "\tio.grpc#grpc-grpclb;1.51.0 from central in [default]\n", + "\tio.grpc#grpc-netty-shaded;1.51.0 from central in [default]\n", + "\tio.grpc#grpc-protobuf;1.51.0 from central in [default]\n", + "\tio.grpc#grpc-protobuf-lite;1.51.0 from central in [default]\n", + "\tio.grpc#grpc-services;1.51.0 from central in [default]\n", + "\tio.grpc#grpc-stub;1.51.0 from central in [default]\n", + "\tio.grpc#grpc-xds;1.51.0 from central in [default]\n", + "\tio.opencensus#opencensus-api;0.31.1 from central in [default]\n", + "\tio.opencensus#opencensus-contrib-http-util;0.31.1 from central in [default]\n", + "\tio.opencensus#opencensus-proto;0.2.0 from central in [default]\n", + "\tio.perfmark#perfmark-api;0.26.0 from central in [default]\n", + "\tit.unimi.dsi#fastutil;7.0.12 from central in [default]\n", + "\tjavax.annotation#javax.annotation-api;1.3.2 from local-m2-cache in [default]\n", + "\torg.checkerframework#checker-qual;3.28.0 from central in [default]\n", + "\torg.codehaus.mojo#animal-sniffer-annotations;1.22 from central in [default]\n", + "\torg.conscrypt#conscrypt-openjdk-uber;2.5.2 from central in [default]\n", + "\torg.projectlombok#lombok;1.16.8 from central in [default]\n", + "\torg.rocksdb#rocksdbjni;6.29.5 from central in [default]\n", + "\torg.threeten#threetenbp;1.6.4 from central in [default]\n", + "\t:: evicted modules:\n", + "\tcom.google.protobuf#protobuf-java-util;3.0.0-beta-3 by [com.google.protobuf#protobuf-java-util;3.21.10] in [default]\n", + "\tcom.google.protobuf#protobuf-java;3.0.0-beta-3 by [com.google.protobuf#protobuf-java;3.21.10] in [default]\n", + "\tcom.google.code.gson#gson;2.3 by [com.google.code.gson#gson;2.10] in [default]\n", + "\t---------------------------------------------------------------------\n", + "\t| | modules || artifacts |\n", + "\t| conf | number| search|dwnlded|evicted|| number|dwnlded|\n", + "\t---------------------------------------------------------------------\n", + "\t| default | 73 | 0 | 0 | 3 || 70 | 0 |\n", + "\t---------------------------------------------------------------------\n", + ":: retrieving :: org.apache.spark#spark-submit-parent-d858c4fe-292f-4adf-8944-9ebef53c59cd\n", + "\tconfs: [default]\n", + "\t0 artifacts copied, 70 already retrieved (0kB/16ms)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "23/07/01 22:00:42 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting default log level to \"WARN\".\n", + "To adjust logging level use sc.setLogLevel(newLevel). 
For SparkR, use setLogLevel(newLevel).\n" + ] + } + ], + "source": [ + "# Create the PySpark session\n", + "from pyspark.sql import SparkSession\n", + "\n", + "spark = SparkSession.builder \\\n", + " .appName(\"Spark NLP\")\\\n", + " .master(\"local[*]\")\\\n", + " .config(\"spark.driver.memory\",\"16G\")\\\n", + " .config(\"spark.driver.maxResultSize\", \"0\") \\\n", + " .config(\"spark.kryoserializer.buffer.max\", \"2000M\")\\\n", + " .config(\"spark.jars.packages\", \"com.johnsnowlabs.nlp:spark-nlp_2.12:5.0.0\")\\\n", + " .getOrCreate()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "a3f563d5", + "metadata": {}, + "outputs": [], + "source": [ + "# Let's use a dataset where we can visually control similarity\n", + "# Documents are paired as 1-2, 3-4, 5-6, 7-8 and were deliberately written to be similar\n", + "data = spark.createDataFrame(\n", + " [\n", + " [\"First document, this is my first sentence. This is my second sentence.\"],\n", + " [\"Second document, this is my second sentence. This is my second sentence.\"],\n", + " [\"Third document, climate change is arguably one of the most pressing problems of our time.\"],\n", + " [\"Fourth document, climate change is definitely one of the most pressing problems of our time.\"],\n", + " [\"Fifth document, Florence in Italy, is among the most beautiful cities in Europe.\"],\n", + " [\"Sixth document, Florence in Italy, is a very beautiful city in Europe like Lyon in France.\"],\n", + " [\"Seventh document, the French Riviera is the Mediterranean coastline of the southeast corner of France.\"],\n", + " [\"Eighth document, the warmest place in France is the French Riviera coast in Southern France.\"]\n", + " ]\n", + " ).toDF(\"text\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "34604126", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[Stage 0:>                                                          (0 + 1) / 1]\r" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+------------------------------------------------------------------------------------------------------+\n", + "|text                                                                                                  |\n", + "+------------------------------------------------------------------------------------------------------+\n", + "|First document, this is my first sentence. This is my second sentence.                                |\n", + "|Second document, this is my second sentence. This is my second sentence.                              |\n", + "|Third document, climate change is arguably one of the most pressing problems of our time.             |\n", + "|Fourth document, climate change is definitely one of the most pressing problems of our time.          |\n", + "|Fifth document, Florence in Italy, is among the most beautiful cities in Europe.                      |\n", + "|Sixth document, Florence in Italy, is a very beautiful city in Europe like Lyon in France.            |\n", + "|Seventh document, the French Riviera is the Mediterranean coastline of the southeast corner of France.|\n", + "|Eighth document, the warmest place in France is the French Riviera coast in Southern France.          |\n", + "+------------------------------------------------------------------------------------------------------+\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "                                                                                \r" + ] + } + ], + "source": [ + "data.show(10, False)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "945e787d", + "metadata": {}, + "source": [ + "## A document similarity ranker pipeline\n", + "### The document similarity ranker works downstream of other annotators generating sentence embeddings. 
In this example we'll use RoBertaSentenceEmbeddings.\n", + "The pipeline will use the following steps:\n", + "- document_assembler to annotate the documents\n", + "- sentence_detector to detect sentences\n", + "- tokenizer to apply tokenization\n", + "- sentence_embeddings to create the necessary sentence embedding representation\n", + "- document_similarity_ranker to extract the similar documents via annotator configuration\n", + "- document_similarity_ranker_finisher to extract the columns of interest for this new annotator" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "d4d2bd1d", + "metadata": {}, + "source": [ + "## DocumentSimilarityRankerApproach: input parameter setters overview\n", + "- setInputCols(\"sentence_embeddings\") : sets the input column\n", + "- setOutputCol(\"doc_similarity_rankings\") : sets the output column\n", + "- setSimilarityMethod(\"brp\") : selects the LSH method (brp|mh) used for the approximate nearest neighbours search\n", + "- setNumberOfNeighbours(10) : sets the desired number of similar documents returned for each document in the set\n", + "- setBucketLength(2.0) : LSH parameter used to control the average size of hash buckets and improve recall\n", + "- setNumHashTables(3) : LSH parameter used to control the number of hash tables used in LSH OR-amplification and improve recall\n", + "- setVisibleDistances(True) : makes the neighbour distances visible in the result, useful for debugging\n", + "- setIdentityRanking(False) : controls whether a document is also ranked against itself at identity distance (0.0), useful for debugging" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "0b36d5cd", + "metadata": {}, + "source": [ + "## DocumentSimilarityRankerFinisher: output parameters overview\n", + "- setInputCols(\"doc_similarity_rankings\") : reads the result column to extract IDs and distances\n", + "- setOutputCols(\n", + " \"finished_doc_similarity_rankings_id\",\n", + " \"finished_doc_similarity_rankings_neighbors\") : selects the columns holding the document query ID and the neighbour documents that result from the search run" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "9a8f9eae", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "sent_roberta_base download started this may take some time.\n", + "Approximate size to download 284.8 MB\n", + "[ | ]sent_roberta_base download started this may take some time.\n", + "Approximate size to download 284.8 MB\n", + "Download done! 
Loading the resource.\n", + "[ / ]" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-07-01 22:01:11.233544: I external/org_tensorflow/tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n", + "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[ \\ ]" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: An illegal reflective access operation has occurred\n", + "WARNING: Illegal reflective access by org.apache.spark.util.SizeEstimator$ (file:/Users/stefanolori/opt/anaconda3/envs/spknlp/lib/python3.8/site-packages/pyspark/jars/spark-core_2.12-3.3.1.jar) to field java.lang.ref.Reference.referent\n", + "WARNING: Please consider reporting this to the maintainers of org.apache.spark.util.SizeEstimator$\n", + "WARNING: Use --illegal-access=warn to enable warnings of further illegal reflective access operations\n", + "WARNING: All illegal access operations will be denied in a future release\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[OK!]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " \r" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "23/07/01 22:01:22 WARN InstanceBuilder$NativeBLAS: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS\n", + "23/07/01 22:01:22 WARN InstanceBuilder$NativeBLAS: Failed to load implementation from:dev.ludovic.netlib.blas.ForeignLinkerBLAS\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " \r" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+-----------------------------------+------------------------------------------+\n", + "|finished_doc_similarity_rankings_id|finished_doc_similarity_rankings_neighbors|\n", + "+-----------------------------------+------------------------------------------+\n", + "|1510101612 |[(1634839239,0.12448559273510636)] |\n", + "|1634839239 |[(1510101612,0.12448559273510636)] |\n", + "|-612640902 |[(1274183715,0.12201215887654807)] |\n", + "|1274183715 |[(-612640902,0.12201215887654807)] |\n", + "|-1320876223 |[(1293373212,0.17848861258809434)] |\n", + "|1293373212 |[(-1320876223,0.17848861258809434)] |\n", + "|-1548374770 |[(-1719102856,0.2329717161223739)] |\n", + "|-1719102856 |[(-1548374770,0.2329717161223739)] |\n", + "+-----------------------------------+------------------------------------------+\n", + "\n" + ] + } + ], + "source": [ + "from sparknlp.annotator.similarity.document_similarity_ranker import *\n", + "\n", + "document_assembler = DocumentAssembler() \\\n", + " .setInputCol(\"text\") \\\n", + " .setOutputCol(\"document\")\n", + "sentence_detector = SentenceDetector() \\\n", + " .setInputCols([\"document\"]) \\\n", + " .setOutputCol(\"sentence\")\n", + "tokenizer = Tokenizer() \\\n", + " .setInputCols([\"sentence\"]) \\\n", + " .setOutputCol(\"token\")\n", + "\n", + "sentence_embeddings = RoBertaSentenceEmbeddings.pretrained() \\\n", + " .setInputCols([\"document\"]) \\\n", + " .setOutputCol(\"sentence_embeddings\")\n", + "\n", + "document_similarity_ranker = DocumentSimilarityRankerApproach() \\\n", + " .setInputCols(\"sentence_embeddings\") \\\n", + " .setOutputCol(\"doc_similarity_rankings\") \\\n", + " 
.setSimilarityMethod(\"brp\") \\\n", + " .setNumberOfNeighbours(1) \\\n", + " .setBucketLength(2.0) \\\n", + " .setNumHashTables(3) \\\n", + " .setVisibleDistances(True) \\\n", + " .setIdentityRanking(False)\n", + "\n", + "document_similarity_ranker_finisher = DocumentSimilarityRankerFinisher() \\\n", + " .setInputCols(\"doc_similarity_rankings\") \\\n", + " .setOutputCols(\n", + " \"finished_doc_similarity_rankings_id\",\n", + " \"finished_doc_similarity_rankings_neighbors\") \\\n", + " .setExtractNearestNeighbor(True)\n", + "\n", + "pipeline = Pipeline(stages=[\n", + " document_assembler,\n", + " sentence_detector,\n", + " tokenizer,\n", + " sentence_embeddings,\n", + " document_similarity_ranker,\n", + " document_similarity_ranker_finisher\n", + " ])\n", + "\n", + "docSimRankerPipeline = pipeline.fit(data).transform(data)\n", + "# TODO add write/read pipeline\n", + "(\n", + " docSimRankerPipeline\n", + " .select(\n", + " \"finished_doc_similarity_rankings_id\",\n", + " \"finished_doc_similarity_rankings_neighbors\"\n", + " ).show(10, False)\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "54eca293", + "metadata": {}, + "source": [ + "## Result analysis: confirming consistent results\n", + "#### The test confirms the initial hypothesis: the documents were created in similar pairs 1-2, 3-4, 5-6, 7-8.\n", + "For instance, documents 1 and 2 are detected as each other's best neighbor at the very same distance:\n", + "- document ID 1510101612 has its best similar document in (1634839239,0.12448559273510636) at distance 0.12448559273510636\n", + "- document ID 1634839239 has its best similar document in (1510101612,0.12448559273510636) at distance 0.12448559273510636\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2cde88af", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}
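For reference, the neighbors column produced by the finisher serializes each (id, distance) pair as text, e.g. [(1634839239,0.12448559273510636)]. Below is a minimal PySpark sketch of how that string can be unpacked; it is illustrative only — the nearest_id/nearest_distance names are made up here, and setExtractNearestNeighbor(True) already derives the equivalent nearest_neighbor_id/nearest_neighbor_distance columns, as the finisher sources further down show.

from pyspark.sql.functions import col, regexp_extract

neighbors_col = "finished_doc_similarity_rankings_neighbors"

# Pull the first (id, distance) tuple out of the serialized neighbors string
parsed = (docSimRankerPipeline
          .withColumn("nearest_id",
                      regexp_extract(col(neighbors_col), r"\((-?\d+),", 1).cast("int"))
          .withColumn("nearest_distance",
                      regexp_extract(col(neighbors_col), r",([^)]+)\)", 1).cast("double")))

parsed.select("finished_doc_similarity_rankings_id",
              "nearest_id", "nearest_distance").show(truncate=False)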
+"""Contains classes for DocumentSimilarityRanker.""" + +from sparknlp.common import * +from pyspark import keyword_only +from pyspark.ml.param import TypeConverters, Params, Param +from sparknlp.internal import AnnotatorTransformer + + +class DocumentSimilarityRankerApproach(AnnotatorApproach, HasEnableCachingProperties): + inputAnnotatorTypes = [AnnotatorType.SENTENCE_EMBEDDINGS] + + outputAnnotatorType = AnnotatorType.DOC_SIMILARITY_RANKINGS + + similarityMethod = Param(Params._dummy(), + "similarityMethod", + "The similarity method used to calculate the neighbours. (Default: 'brp', " + "Bucketed Random Projection for Euclidean Distance)", + typeConverter=TypeConverters.toString) + + numberOfNeighbours = Param(Params._dummy(), + "numberOfNeighbours", + "The number of neighbours the model will return (Default:`10`)", + typeConverter=TypeConverters.toInt) + + bucketLength = Param(Params._dummy(), + "bucketLength", + "The bucket length that controls the average size of hash buckets. " + "A larger bucket length (i.e., fewer buckets) increases the probability of features " + "being hashed to the same bucket (increasing the numbers of true and false positives).", + typeConverter=TypeConverters.toFloat) + + numHashTables = Param(Params._dummy(), + "numHashTables", + "number of hash tables, where increasing number of hash tables lowers the " + "false negative rate,and decreasing it improves the running performance.", + typeConverter=TypeConverters.toInt) + + visibleDistances = Param(Params._dummy(), + "visibleDistances", + "Whether to set visibleDistances in ranking output (Default: `false`).", + typeConverter=TypeConverters.toBoolean) + + identityRanking = Param(Params._dummy(), + "identityRanking", + "Whether to include identity in ranking result set. Useful for debug. (Default: `false`).", + typeConverter=TypeConverters.toBoolean) + + def setSimilarityMethod(self, value): + """Sets the similarity method used to calculate the neighbours. + (Default: `"brp"`, Bucketed Random Projection for Euclidean Distance) + + Parameters + ---------- + value : str + the similarity method to calculate the neighbours. + """ + return self._set(similarityMethod=value) + + def setNumberOfNeighbours(self, value): + """Sets The number of neighbours the model will return for each document(Default:`"10"`). + + Parameters + ---------- + value : str + the number of neighbours the model will return for each document. + """ + return self._set(numberOfNeighbours=value) + + def setBucketLength(self, value): + """Sets the bucket length that controls the average size of hash buckets (Default:`"2.0"`). + + Parameters + ---------- + value : float + Sets the bucket length that controls the average size of hash buckets. + """ + return self._set(bucketLength=value) + + def setNumHashTables(self, value): + """Sets the number of hash tables. + + Parameters + ---------- + value : int + Sets the number of hash tables. + """ + return self._set(numHashTables=value) + + def setVisibleDistances(self, value): + """Sets the document distances visible in the result set. + + Parameters + ---------- + value : bool + Sets the document distances visible in the result set. + Default('False') + """ + return self._set(visibleDistances=value) + + def setIdentityRanking(self, value): + """Sets the document identity ranking inclusive in the result set. + + Parameters + ---------- + value : bool + Sets the document identity ranking inclusive in the result set. + Useful for debugging. + Default('False'). 
+ """ + return self._set(identityRanking=value) + + @keyword_only + def __init__(self): + super(DocumentSimilarityRankerApproach, self)\ + .__init__(classname="com.johnsnowlabs.nlp.annotators.similarity.DocumentSimilarityRankerApproach") + self._setDefault( + similarityMethod="brp", + numberOfNeighbours=10, + bucketLength=2.0, + numHashTables=3, + visibleDistances=False, + identityRanking=False + ) + + def _create_model(self, java_model): + return DocumentSimilarityRankerModel(java_model=java_model) + + +class DocumentSimilarityRankerModel(AnnotatorModel, HasEmbeddingsProperties): + + name = "DocumentSimilarityRankerModel" + inputAnnotatorTypes = [AnnotatorType.SENTENCE_EMBEDDINGS] + outputAnnotatorType = AnnotatorType.DOC_SIMILARITY_RANKINGS + + def __init__(self, classname="com.johnsnowlabs.nlp.annotators.similarity.DocumentSimilarityRankerModel", + java_model=None): + super(DocumentSimilarityRankerModel, self).__init__( + classname=classname, + java_model=java_model + ) + + +class DocumentSimilarityRankerFinisher(AnnotatorTransformer): + + inputCols = Param(Params._dummy(), + "inputCols", + "name of input annotation cols containing document similarity ranker results", + typeConverter=TypeConverters.toListString) + outputCols = Param(Params._dummy(), + "outputCols", + "output DocumentSimilarityRankerFinisher output cols", + typeConverter=TypeConverters.toListString) + extractNearestNeighbor = Param(Params._dummy(), "extractNearestNeighbor", + "whether to extract the nearest neighbor document", + typeConverter=TypeConverters.toBoolean) + + name = "DocumentSimilarityRankerFinisher" + + @keyword_only + def __init__(self): + super(DocumentSimilarityRankerFinisher, self).__init__(classname="com.johnsnowlabs.nlp.finisher.DocumentSimilarityRankerFinisher") + self._setDefault( + extractNearestNeighbor=False + ) + + @keyword_only + def setParams(self): + kwargs = self._input_kwargs + return self._set(**kwargs) + + def setInputCols(self, *value): + """Sets name of input annotation columns containing embeddings. + + Parameters + ---------- + *value : str + Input columns for the annotator + """ + + if len(value) == 1 and type(value[0]) == list: + return self._set(inputCols=value[0]) + else: + return self._set(inputCols=list(value)) + + def setOutputCols(self, *value): + """Sets names of finished output columns. + + Parameters + ---------- + *value : List[str] + Input columns for the annotator + """ + + if len(value) == 1 and type(value[0]) == list: + return self._set(outputCols=value[0]) + else: + return self._set(outputCols=list(value)) + + def setExtractNearestNeighbor(self, value): + """Sets whether to extract the nearest neighbor document, by default False. 
+ + Parameters + ---------- + value : bool + Whether to extract the nearest neighbor document + """ + + return self._set(extractNearestNeighbor=value) + + def getInputCols(self): + """Gets the names of the input annotation columns.""" + return self.getOrDefault(self.inputCols) + + def getOutputCols(self): + """Gets the names of the output annotation columns.""" + if len(self.getOrDefault(self.outputCols)) == 0: + return ["finished_" + input_col for input_col in self.getInputCols()] + else: + return self.getOrDefault(self.outputCols) \ No newline at end of file diff --git a/python/sparknlp/common/annotator_type.py b/python/sparknlp/common/annotator_type.py index 2d0eb1ed54c9e8..0cd230a5ec480d 100644 --- a/python/sparknlp/common/annotator_type.py +++ b/python/sparknlp/common/annotator_type.py @@ -35,3 +35,4 @@ class AnnotatorType(object): NODE = "node" TABLE = "table" DUMMY = "dummy" + DOC_SIMILARITY_RANKINGS = "doc_similarity_rankings" diff --git a/python/test/annotator/similarity/__init__.py b/python/test/annotator/similarity/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/python/test/annotator/similarity/doc_similarity_ranker_test.py b/python/test/annotator/similarity/doc_similarity_ranker_test.py new file mode 100644 index 00000000000000..f9a93f4d12ee2d --- /dev/null +++ b/python/test/annotator/similarity/doc_similarity_ranker_test.py @@ -0,0 +1,90 @@ +# Copyright 2017-2022 John Snow Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import pytest + +from sparknlp.annotator import * +from sparknlp.annotator.similarity.document_similarity_ranker import * +from sparknlp.base import * +from test.util import SparkSessionForTest + + +@pytest.mark.slow +class DocumentSimilarityRankerTestSpec(unittest.TestCase): + def setUp(self): + self.spark = SparkSessionForTest.spark + + self.data = SparkSessionForTest.spark.createDataFrame([ + ["First document, this is my first sentence. This is my second sentence."], + ["Second document, this is my second sentence. 
This is my second sentence."], + ["Third document, climate change is arguably one of the most pressing problems of our time."], + ["Fourth document, climate change is definitely one of the most pressing problems of our time."], + ["Fifth document, Florence in Italy, is among the most beautiful cities in Europe."], + ["Sixth document, Florence in Italy, is a very beautiful city in Europe like Lyon in France."], + ["Seventh document, the French Riviera is the Mediterranean coastline of the southeast corner of France."], + ["Eighth document, the warmest place in France is the French Riviera coast in Southern France."] + ]).toDF("text") + + def runTest(self): + document_assembler = DocumentAssembler() \ + .setInputCol("text") \ + .setOutputCol("document") + sentence_detector = SentenceDetector() \ + .setInputCols(["document"]) \ + .setOutputCol("sentence") + tokenizer = Tokenizer() \ + .setInputCols(["sentence"]) \ + .setOutputCol("token") + + sentence_embeddings = RoBertaSentenceEmbeddings.pretrained() \ + .setInputCols(["document"]) \ + .setOutputCol("sentence_embeddings") + + document_similarity_ranker = DocumentSimilarityRankerApproach() \ + .setInputCols("sentence_embeddings") \ + .setOutputCol("doc_similarity_rankings") \ + .setSimilarityMethod("brp") \ + .setNumberOfNeighbours(10) \ + .setBucketLength(2.0) \ + .setNumHashTables(3) \ + .setVisibleDistances(True) \ + .setIdentityRanking(True) + + document_similarity_ranker_finisher = DocumentSimilarityRankerFinisher() \ + .setInputCols("doc_similarity_rankings") \ + .setOutputCols( + "finished_doc_similarity_rankings_id", + "finished_doc_similarity_rankings_neighbors") \ + .setExtractNearestNeighbor(True) + + pipeline = Pipeline(stages=[ + document_assembler, + sentence_detector, + tokenizer, + sentence_embeddings, + document_similarity_ranker, + document_similarity_ranker_finisher + ]) + + model = pipeline.fit(self.data) + + ( + model + .transform(self.data) + .select("text", + "finished_doc_similarity_rankings_id", + "finished_doc_similarity_rankings_neighbors") + .show(10, False) + ) \ No newline at end of file diff --git a/src/main/scala/com/johnsnowlabs/nlp/AnnotatorType.scala b/src/main/scala/com/johnsnowlabs/nlp/AnnotatorType.scala index 6a51a15b9e83cd..7e420f7f65eb43 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/AnnotatorType.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/AnnotatorType.scala @@ -38,5 +38,5 @@ object AnnotatorType { val NODE = "node" val TABLE = "table" val DUMMY = "dummy" - + val DOC_SIMILARITY_RANKINGS = "doc_similarity_rankings" } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/similarity/DocumentSimilarityRankerApproach.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/similarity/DocumentSimilarityRankerApproach.scala new file mode 100644 index 00000000000000..1282303c995815 --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/similarity/DocumentSimilarityRankerApproach.scala @@ -0,0 +1,223 @@ +package com.johnsnowlabs.nlp.annotators.similarity + +import com.johnsnowlabs.nlp.AnnotatorType.{DOC_SIMILARITY_RANKINGS, SENTENCE_EMBEDDINGS} +import com.johnsnowlabs.nlp.{AnnotatorApproach, HasEnableCachingProperties} +import com.johnsnowlabs.storage.HasStorageRef +import org.apache.spark.ml.PipelineModel +import org.apache.spark.ml.feature.{ + BucketedRandomProjectionLSH, + BucketedRandomProjectionLSHModel, + MinHashLSH +} +import org.apache.spark.ml.functions.array_to_vector +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.param.{BooleanParam, Param} +import 
org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable} +import org.apache.spark.sql.functions.{col, flatten, udf} +import org.apache.spark.sql.{DataFrame, Dataset} + +import scala.util.hashing.MurmurHash3 + +sealed trait NeighborAnnotation { + def neighbors: Array[_] +} + +case class IndexedNeighbors(neighbors: Array[Int]) extends NeighborAnnotation + +case class IndexedNeighborsWithDistance(neighbors: Array[(Int, Double)]) + extends NeighborAnnotation + +case class NeighborsResultSet(result: (Int, NeighborAnnotation)) + +class DocumentSimilarityRankerApproach(override val uid: String) + extends AnnotatorApproach[DocumentSimilarityRankerModel] + with HasEnableCachingProperties { + + override val description: AnnotatorType = "LSH based document similarity annotator" + + /** Annotator reference id. Used to identify elements in metadata or to refer to this annotator + * type + */ + def this() = this(Identifiable.randomUID("DocumentSimilarityRankerApproach")) + + override val inputAnnotatorTypes: Array[AnnotatorType] = Array(SENTENCE_EMBEDDINGS) + + override val outputAnnotatorType: AnnotatorType = DOC_SIMILARITY_RANKINGS + + val LSH_INPUT_COL_NAME = "features" + + val LSH_OUTPUT_COL_NAME = "hashes" + + val INDEX_COL_NAME = "index" + + val DISTANCE = "distCol" + + val INPUT_EMBEDDINGS = "sentence_embeddings.embeddings" + + val TEXT = "text" + + /** The similarity method used to calculate the neighbours. (Default: `"brp"`, Bucketed Random + * Projection for Euclidean Distance) + * + * @group param + */ + val similarityMethod = new Param[String]( + this, + "similarityMethod", + """The similarity method used to calculate the neighbours. + |(Default: `"brp"`, Bucketed Random Projection for Euclidean Distance) + |""".stripMargin) + + def setSimilarityMethod(value: String): this.type = set(similarityMethod, value) + + def getSimilarityMethod: String = $(similarityMethod) + + /** The number of neighbours the model will return (Default: `10`). + * + * @group param + */ + val numberOfNeighbours = new Param[Int]( + this, + "numberOfNeighbours", + """The number of neighbours the model will return for each document (Default: `10`)""") + + def setNumberOfNeighbours(value: Int): this.type = set(numberOfNeighbours, value) + + def getNumberOfNeighbours: Int = $(numberOfNeighbours) + + val bucketLength = new Param[Double]( + this, + "bucketLength", + """The bucket length that controls the average size of hash buckets. + |A larger bucket length (i.e., fewer buckets) increases the probability of features being hashed + |to the same bucket (increasing the numbers of true and false positives) + |""".stripMargin) + + def setBucketLength(value: Double): this.type = set(bucketLength, value) + + def getBucketLength: Double = $(bucketLength) + + val numHashTables = new Param[Int]( + this, + "numHashTables", + """Number of hash tables, where increasing the number of hash tables lowers the false negative rate, + |and decreasing it improves the running performance. + |""".stripMargin) + + def setNumHashTables(value: Int): this.type = set(numHashTables, value) + + val visibleDistances = new BooleanParam( + this, + "visibleDistances", + "Whether to set visibleDistances in ranking output (Default: `false`)") + + def setVisibleDistances(value: Boolean): this.type = set(visibleDistances, value) + + def getVisibleDistances: Boolean = $(visibleDistances) + + val identityRanking = new BooleanParam( + this, + "identityRanking", + "Whether to include identity in ranking result set. Useful for debug. 
(Default: `false`)") + + def setIdentityRanking(value: Boolean): this.type = set(identityRanking, value) + + def getIdentityRanking: Boolean = $(identityRanking) + + setDefault( + similarityMethod -> "brp", + numberOfNeighbours -> 10, + bucketLength -> 2.0, + numHashTables -> 3, + visibleDistances -> false, + identityRanking -> false) + + def getNeighborsResultSet( + query: (Int, Vector), + similarityDataset: DataFrame): NeighborsResultSet = { + + val lsh = $(similarityMethod) match { + case "brp" => + new BucketedRandomProjectionLSH() + .setBucketLength($(bucketLength)) + .setNumHashTables($(numHashTables)) + .setInputCol(LSH_INPUT_COL_NAME) + .setOutputCol(LSH_OUTPUT_COL_NAME) + case "mh" => + new MinHashLSH() + .setNumHashTables($(numHashTables)) + .setInputCol(LSH_INPUT_COL_NAME) + .setOutputCol(LSH_OUTPUT_COL_NAME) + case _ => + throw new IllegalArgumentException(s"${$(similarityMethod)} is not a valid value.") + } + + val model = lsh.fit(similarityDataset) + + query match { + case (index, queryVector) => + val _similarityDataset = + if (getIdentityRanking) { + similarityDataset + } else { + similarityDataset.where(col(INDEX_COL_NAME) =!= index) + } + + val similarRankedDocs = + model.approxNearestNeighbors(_similarityDataset, queryVector, getNumberOfNeighbours) + + if (getVisibleDistances) { + val rankedNeighboursWithDistances = similarRankedDocs + .select(INDEX_COL_NAME, DISTANCE) + .collect() + .map(row => (row.getInt(0), row.getDouble(1))) + + NeighborsResultSet((index, IndexedNeighborsWithDistance(rankedNeighboursWithDistances))) + } else { + val rankedNeighbours = similarRankedDocs + .select(INDEX_COL_NAME) + .collect() + .map(_.getInt(0)) + + NeighborsResultSet((index, IndexedNeighbors(rankedNeighbours))) + } + case _ => throw new IllegalArgumentException("query is not of type (Int, Vector)") + } + } + + override def train( + dataset: Dataset[_], + recursivePipeline: Option[PipelineModel]): DocumentSimilarityRankerModel = { + + val embeddingsDataset = dataset.withColumn(LSH_INPUT_COL_NAME, col(INPUT_EMBEDDINGS)) + + val similarityDataset: DataFrame = embeddingsDataset + .withColumn(LSH_INPUT_COL_NAME, flatten(col(LSH_INPUT_COL_NAME))) + .withColumn(LSH_INPUT_COL_NAME, array_to_vector(col(LSH_INPUT_COL_NAME))) + + val mh3UDF = udf { (s: String) => MurmurHash3.stringHash(s, MurmurHash3.stringSeed) } + + val similarityDatasetWithIndex = + similarityDataset.withColumn(INDEX_COL_NAME, mh3UDF(col(TEXT))) + + val indexedVectorTuples = similarityDatasetWithIndex + .select(INDEX_COL_NAME, LSH_INPUT_COL_NAME) + .rdd + .map(x => (x.getAs[Int](INDEX_COL_NAME), x.getAs[Vector](LSH_INPUT_COL_NAME))) + .collect() + + val similarityMappings: Map[Int, NeighborAnnotation] = indexedVectorTuples + .map(query => getNeighborsResultSet(query, similarityDatasetWithIndex)) + .map(_.result) + .toMap + + new DocumentSimilarityRankerModel() + .setSimilarityMappings(Map("similarityMappings" -> similarityMappings)) + } +} + +/** This is the companion object of [[DocumentSimilarityRankerApproach]]. Please refer to that + * class for the documentation. 
+ */ +object DocumentSimilarityRankerApproach + extends DefaultParamsReadable[DocumentSimilarityRankerApproach]
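The approach above delegates the neighbour search to Spark ML's LSH estimators. As a minimal PySpark sketch of what the "brp" branch of getNeighborsResultSet does under the hood — the toy vectors and column values are illustrative, and an active spark session is assumed:

from pyspark.ml.feature import BucketedRandomProjectionLSH
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Toy 2-d vectors standing in for the flattened sentence embeddings
df = spark.createDataFrame(
    [(0, Vectors.dense([1.0, 1.0])), (1, Vectors.dense([1.0, -1.0])),
     (2, Vectors.dense([-1.0, -1.0])), (3, Vectors.dense([-1.0, 1.0]))],
    ["index", "features"])

# Same estimator and parameters the annotator configures for "brp"
brp = BucketedRandomProjectionLSH(
    inputCol="features", outputCol="hashes", bucketLength=2.0, numHashTables=3)
model = brp.fit(df)

# approxNearestNeighbors adds a distCol column, which the annotator reads
# back as the neighbour distances when visibleDistances is enabled
key = Vectors.dense([1.0, 0.0])
model.approxNearestNeighbors(df, key, numNearestNeighbors=2) \
    .select("index", "distCol").show()

The "mh" branch swaps in MinHashLSH with the same numHashTables setting.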
diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/similarity/DocumentSimilarityRankerModel.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/similarity/DocumentSimilarityRankerModel.scala new file mode 100644 index 00000000000000..eb75d78c7df430 --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/similarity/DocumentSimilarityRankerModel.scala @@ -0,0 +1,78 @@ +package com.johnsnowlabs.nlp.annotators.similarity + +import com.johnsnowlabs.nlp.AnnotatorType.{DOC_SIMILARITY_RANKINGS, SENTENCE_EMBEDDINGS} +import com.johnsnowlabs.nlp.embeddings.HasEmbeddingsProperties +import com.johnsnowlabs.nlp.serialization.MapFeature +import com.johnsnowlabs.nlp.{ + Annotation, + AnnotatorModel, + HasSimpleAnnotate, + ParamsAndFeaturesReadable, + ParamsAndFeaturesWritable +} +import com.johnsnowlabs.storage.HasStorageRef +import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable} +import org.apache.spark.sql.functions.col + +import scala.util.hashing.MurmurHash3 + +class DocumentSimilarityRankerModel(override val uid: String) + extends AnnotatorModel[DocumentSimilarityRankerModel] + with HasSimpleAnnotate[DocumentSimilarityRankerModel] + with HasEmbeddingsProperties + with ParamsAndFeaturesWritable { + + override val inputAnnotatorTypes: Array[AnnotatorType] = Array(SENTENCE_EMBEDDINGS) + + override val outputAnnotatorType: AnnotatorType = DOC_SIMILARITY_RANKINGS + + def this() = this(Identifiable.randomUID("DOC_SIMILARITY_RANKER")) + + /** Mapping of each document index to its neighbor annotations + * + * @group param + */ + val similarityMappings: MapFeature[String, Map[Int, NeighborAnnotation]] = + new MapFeature(this, "similarityMappings") + + /** @group setParam */ + def setSimilarityMappings(value: Map[String, Map[Int, NeighborAnnotation]]): this.type = + set(similarityMappings, value) + + def getSimilarityMappings: Map[Int, NeighborAnnotation] = + $$(similarityMappings).getOrElse("similarityMappings", Map.empty) + + setDefault(inputCols -> Array(SENTENCE_EMBEDDINGS), outputCol -> DOC_SIMILARITY_RANKINGS) + + /** takes a document and annotations and produces new annotations of this annotator's annotation + * type + * + * @param annotations + * Annotations that correspond to inputAnnotationCols generated by previous annotators if any + * @return + * any number of annotations processed for every input annotation. Not necessarily a + * one-to-one relationship + */ + override def annotate(annotations: Seq[Annotation]): Seq[Annotation] = + annotations.map(annotation => { + val inputResult = annotation.result + val targetIndex = MurmurHash3.stringHash(inputResult, MurmurHash3.stringSeed) + val neighborsAnnotation: NeighborAnnotation = + getSimilarityMappings.getOrElse(targetIndex, IndexedNeighbors(Array.empty)) // index NA + + Annotation( + annotatorType = outputAnnotatorType, + begin = annotation.begin, + end = annotation.end, + result = annotation.result, + metadata = annotation.metadata + + ("lshId" -> targetIndex.toString) + + ("lshNeighbors" -> neighborsAnnotation.neighbors.mkString("[", ",", "]")), + embeddings = annotation.embeddings) + }) +} + +trait ReadableDocumentSimilarityRanker + extends ParamsAndFeaturesReadable[DocumentSimilarityRankerModel] + +object DocumentSimilarityRankerModel extends ReadableDocumentSimilarityRanker diff --git a/src/main/scala/com/johnsnowlabs/nlp/finisher/DocumentSimilarityRankerFinisher.scala b/src/main/scala/com/johnsnowlabs/nlp/finisher/DocumentSimilarityRankerFinisher.scala new file mode 100644 index 00000000000000..3aeb7ccb9dd29b --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/nlp/finisher/DocumentSimilarityRankerFinisher.scala @@ -0,0 +1,181 @@ +package com.johnsnowlabs.nlp.finisher + +import com.johnsnowlabs.nlp.AnnotatorType +import com.johnsnowlabs.nlp.util.FinisherUtil +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam} +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.{DataFrame, Dataset} + +case class DocumentSimilarityRankerFinisher(override val uid: String) + extends Transformer + with DefaultParamsWritable { + + def this() = this(Identifiable.randomUID("DOCUMENT_SIMILARITY_RANKER_FINISHER")) + + val LSH_ID_COL_NAME = "lshId" + + val LSH_NEIGHBORS_COL_NAME = "lshNeighbors" + + val FINISHED_DOC_SIM_RANKER_ID_DEFAULT = "finished_doc_similarity_rankings_id" + + val FINISHED_DOC_SIM_RANKER_NEIGHBORS_DEFAULT = "finished_doc_similarity_rankings_neighbors" + + /** Name of input annotation cols containing similar documents + * + * @group param + */ + val inputCols: StringArrayParam = + new StringArrayParam( + this, + "inputCols", + "Name of input annotation cols containing similar documents") + + /** Name of input annotation cols containing similar documents + * + * @group setParam + */ + def setInputCols(value: Array[String]): this.type = set(inputCols, value) + + /** Name of input annotation cols containing similar documents + * + * @group setParam + */ + def setInputCols(value: String*): this.type = setInputCols(value.toArray) + + /** Name of input annotation cols containing similar documents + * + * @group getParam + */ + def getInputCols: Array[String] = $(inputCols) + + /** Name of DocumentSimilarityRankerFinisher output cols + * + * @group param + */ + val outputCols: StringArrayParam = + new StringArrayParam( + this, + "outputCols", + "Name of DocumentSimilarityRankerFinisher output cols") + + /** Name of DocumentSimilarityRankerFinisher output cols + * + * @group setParam + */ + def setOutputCols(value: Array[String]): this.type = set(outputCols, value) + + /** Name of DocumentSimilarityRankerFinisher output cols + * + * @group setParam + */ + def setOutputCols(value: String*): this.type = setOutputCols(value.toArray) + + /** 
Name of DocumentSimilarityRankerFinisher output cols + * + * @group getParam + */ + def getOutputCols: Array[String] = get(outputCols).getOrElse(getInputCols.map("finished_" + _)) + + val extractNearestNeighbor: BooleanParam = + new BooleanParam( + this, + "extractNearestNeighbor", + doc = "Extract the best neighbors with distance") + + /** Set flag to extract best neighbor with distance + * + * @group setParam + */ + def setExtractNearestNeighbor(value: Boolean): this.type = set(extractNearestNeighbor, value) + + /** Whether to extract the nearest neighbor document + * + * @group getParam + */ + def getExtractNearestNeighbor: Boolean = $(extractNearestNeighbor) + + setDefault(extractNearestNeighbor -> false) + + override def transform(dataset: Dataset[_]): DataFrame = { + + require( + getOutputCols.length == 1 || getOutputCols.length == 2, + "Output column array should have length 1 (default case) or 2 when both the id and neighbors columns are assigned.") + + val (idColName, neighborsColName) = + getOutputCols.length match { + case 1 => (FINISHED_DOC_SIM_RANKER_ID_DEFAULT, FINISHED_DOC_SIM_RANKER_NEIGHBORS_DEFAULT) + case 2 => (getOutputCols(0), getOutputCols(1)) + } + + val transformed = dataset + .withColumn( + idColName, + element_at(col(s"${AnnotatorType.DOC_SIMILARITY_RANKINGS}.metadata"), 1) + .getItem(LSH_ID_COL_NAME) + .cast("int")) + .withColumn( + neighborsColName, + element_at(col(s"${AnnotatorType.DOC_SIMILARITY_RANKINGS}.metadata"), 1) + .getItem(LSH_NEIGHBORS_COL_NAME)) + + val formatted = transformed + .withColumn( + s"no_squared_$neighborsColName", + regexp_replace(col(neighborsColName), "[\\[\\]]", "")) + .withColumn( + s"tuple_extract_$neighborsColName", + regexp_extract(col(s"no_squared_$neighborsColName"), "\\((.*?)\\)", 0)) + .withColumn( + s"no_rounded_$neighborsColName", + regexp_replace(col(s"tuple_extract_$neighborsColName"), "[\\(\\)]", "")) + + val result = + if (getExtractNearestNeighbor) + formatted + .withColumn( + s"split_$neighborsColName", + split(col(s"no_rounded_$neighborsColName"), ",")) + .withColumn( + "nearest_neighbor_id", + element_at(col(s"split_$neighborsColName"), 1).cast(IntegerType)) + .withColumn("nearest_neighbor_distance", element_at(col(s"split_$neighborsColName"), 2)) + else + formatted + + result.drop( + s"no_squared_$neighborsColName", + s"tuple_extract_$neighborsColName", + s"no_rounded_$neighborsColName", + s"split_$neighborsColName") + } + + override def copy(extra: ParamMap): Transformer = defaultCopy(extra) + + override def transformSchema(schema: StructType): StructType = { + val documentSimilarityRankerAnnotators = Seq(AnnotatorType.DOC_SIMILARITY_RANKINGS) + + getInputCols.foreach { annotationColumn => + FinisherUtil.checkIfInputColsExist(getInputCols, schema) + FinisherUtil.checkIfAnnotationColumnIsSparkNLPAnnotation(schema, annotationColumn) + + /** Check that the annotationColumn was produced by the 
DocumentSimilarityRanker annotator. + */ + require( + documentSimilarityRankerAnnotators.contains( + schema(annotationColumn).metadata.getString("annotatorType")), + s"column [$annotationColumn] must be of type DocumentSimilarityRanker") + } + + val outputFields = schema.fields + + StructType(outputFields) + } +} + +object DocumentSimilarityRankerFinisher + extends DefaultParamsReadable[DocumentSimilarityRankerFinisher] diff --git a/src/test/scala/com/johnsnowlabs/nlp/similarity/DocumentSimilarityRankerTestSpec.scala b/src/test/scala/com/johnsnowlabs/nlp/similarity/DocumentSimilarityRankerTestSpec.scala new file mode 100644 index 00000000000000..ccdd8294db6471 --- /dev/null +++ b/src/test/scala/com/johnsnowlabs/nlp/similarity/DocumentSimilarityRankerTestSpec.scala @@ -0,0 +1,275 @@ +package com.johnsnowlabs.nlp.similarity + +import com.johnsnowlabs.nlp.AnnotatorType.DOC_SIMILARITY_RANKINGS +import com.johnsnowlabs.nlp.annotators.Tokenizer +import com.johnsnowlabs.nlp.annotators.sbd.pragmatic.SentenceDetector +import com.johnsnowlabs.nlp.annotators.similarity.DocumentSimilarityRankerApproach +import com.johnsnowlabs.nlp.base.DocumentAssembler +import com.johnsnowlabs.nlp.embeddings.{AlbertEmbeddings, SentenceEmbeddings} +import com.johnsnowlabs.nlp.finisher.DocumentSimilarityRankerFinisher +import com.johnsnowlabs.nlp.util.io.ResourceHelper +import com.johnsnowlabs.nlp.{AnnotatorBuilder, EmbeddingsFinisher} +import com.johnsnowlabs.tags.SlowTest +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.sql.{SparkSession, functions} +import org.apache.spark.sql.functions.{col, element_at, size} +import org.scalatest.flatspec.AnyFlatSpec + +class DocumentSimilarityRankerTestSpec extends AnyFlatSpec { + val spark: SparkSession = ResourceHelper.spark + + "DocumentSimilarityRanker" should "use brp to rank document similarity" taggedAs SlowTest in { + + val smallCorpus = spark + .createDataFrame( + List( + "First document, this is my first sentence. This is my second sentence.", + "Second document, this is my second sentence. 
This is my second sentence.", + "Third document, climate change is arguably one of the most pressing problems of our time.", + "Fourth document, climate change is definitely one of the most pressing problems of our time.", + "Fifth document, Florence in Italy, is among the most beautiful cities in Europe.", + "Sixth document, Florence in Italy, is a very beautiful city in Europe like Lyon in France.", + "Seventh document, the French Riviera is the Mediterranean coastline of the southeast corner of France.", + "Eighth document, the warmest place in France is the French Riviera coast in Southern France.") + .map(Tuple1(_))) + .toDF("text") + + val documentAssembler = new DocumentAssembler() + .setInputCol("text") + .setOutputCol("document") + + val sentence = new SentenceDetector() + .setInputCols("document") + .setOutputCol("sentence") + + val tokenizer = new Tokenizer() + .setInputCols(Array("document")) + .setOutputCol("token") + + val embeddings = AlbertEmbeddings + .pretrained() + .setInputCols("sentence", "token") + .setOutputCol("embeddings") + + val embeddingsSentence = new SentenceEmbeddings() + .setInputCols(Array("document", "embeddings")) + .setOutputCol("sentence_embeddings") + .setPoolingStrategy("AVERAGE") + + val sentenceFinisher = new EmbeddingsFinisher() + .setInputCols("sentence_embeddings") + .setOutputCols("finished_sentence_embeddings") + .setCleanAnnotations(false) + + val docSimilarityRanker = new DocumentSimilarityRankerApproach() + .setInputCols("sentence_embeddings") + .setOutputCol(DOC_SIMILARITY_RANKINGS) + .setSimilarityMethod("brp") + .setNumberOfNeighbours(3) + .setVisibleDistances(true) + .setIdentityRanking(true) + + val documentSimilarityFinisher = new DocumentSimilarityRankerFinisher() + .setInputCols("doc_similarity_rankings") + .setOutputCols( + "finished_doc_similarity_rankings_id", + "finished_doc_similarity_rankings_neighbors") + .setExtractNearestNeighbor(true) + + val pipeline = new Pipeline() + .setStages( + Array( + documentAssembler, + sentence, + tokenizer, + embeddings, + embeddingsSentence, + sentenceFinisher, + docSimilarityRanker, + documentSimilarityFinisher)) + + val trainedPipelineModel = pipeline.fit(smallCorpus) + + val pipelineModelLoc = "./tmp_doc_sim_ranker_brp_pipeline" + trainedPipelineModel.write.overwrite().save(pipelineModelLoc) + val pipelineModel = PipelineModel.load(pipelineModelLoc) + + val transformed = pipelineModel.transform(smallCorpus) + + transformed.select("text", "finished_sentence_embeddings").show() + + // correct if not empty as inclusive query points are at distance 0.0 from themselves + assert(!transformed.where(col("nearest_neighbor_distance") === 0.0).rdd.isEmpty()) + } + + "DocumentSimilarityRanker" should "use min hash to rank document similarity" taggedAs SlowTest in { + + val smallCorpus = spark + .createDataFrame( + List( + "First document, this is my first sentence. This is my second sentence.", + "Second document, this is my second sentence. 
+
+  "DocumentSimilarityRanker" should "use min hash to rank document similarity" taggedAs SlowTest in {
+
+    val smallCorpus = spark
+      .createDataFrame(
+        List(
+          "First document, this is my first sentence. This is my second sentence.",
+          "Second document, this is my second sentence. This is my second sentence.",
+          "Third document, climate change is arguably one of the most pressing problems of our time.",
+          "Fourth document, climate change is definitely one of the most pressing problems of our time.",
+          "Fifth document, Florence in Italy, is among the most beautiful cities in Europe.",
+          "Sixth document, Florence in Italy, is a very beautiful city in Europe like Lyon in France.",
+          "Seventh document, the French Riviera is the Mediterranean coastline of the southeast corner of France.",
+          "Eighth document, the warmest place in France is the French Riviera coast in Southern France.")
+          .map(Tuple1(_)))
+      .toDF("text")
+
+    val documentAssembler = new DocumentAssembler()
+      .setInputCol("text")
+      .setOutputCol("document")
+
+    val sentence = new SentenceDetector()
+      .setInputCols("document")
+      .setOutputCol("sentence")
+
+    val tokenizer = new Tokenizer()
+      .setInputCols(Array("document"))
+      .setOutputCol("token")
+
+    val embeddings = AlbertEmbeddings
+      .pretrained()
+      .setInputCols("sentence", "token")
+      .setOutputCol("embeddings")
+
+    val embeddingsSentence = new SentenceEmbeddings()
+      .setInputCols(Array("document", "embeddings"))
+      .setOutputCol("sentence_embeddings")
+      .setPoolingStrategy("AVERAGE")
+
+    val sentenceFinisher = new EmbeddingsFinisher()
+      .setInputCols("sentence_embeddings")
+      .setOutputCols("finished_sentence_embeddings")
+      .setCleanAnnotations(false)
+
+    val docSimilarityRanker = new DocumentSimilarityRankerApproach()
+      .setInputCols("sentence_embeddings")
+      .setOutputCol(DOC_SIMILARITY_RANKINGS)
+      .setSimilarityMethod("mh")
+      .setNumberOfNeighbours(3)
+      .setVisibleDistances(true)
+      .setIdentityRanking(true)
+
+    val documentSimilarityFinisher = new DocumentSimilarityRankerFinisher()
+      .setInputCols("doc_similarity_rankings")
+      .setOutputCols(
+        "finished_doc_similarity_rankings_id",
+        "finished_doc_similarity_rankings_neighbors")
+      .setExtractNearestNeighbor(true)
+
+    val pipeline = new Pipeline()
+      .setStages(
+        Array(
+          documentAssembler,
+          sentence,
+          tokenizer,
+          embeddings,
+          embeddingsSentence,
+          sentenceFinisher,
+          docSimilarityRanker,
+          documentSimilarityFinisher))
+
+    val trainedPipelineModel = pipeline.fit(smallCorpus)
+
+    val pipelineModelLoc = "./tmp_doc_sim_ranker_mh_pipeline"
+    trainedPipelineModel.write.overwrite().save(pipelineModelLoc)
+    val pipelineModel = PipelineModel.load(pipelineModelLoc)
+
+    val transformed = pipelineModel.transform(smallCorpus)
+
+    // must not be empty: with identity ranking enabled, each query point is its own
+    // nearest neighbor at distance 0.0
+    assert(!transformed.where(col("nearest_neighbor_distance") === 0.0).rdd.isEmpty())
+  }
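The two method names exercised by these tests, "brp" and "mh", match Spark ML's two LSH
families: BucketedRandomProjectionLSH (Euclidean distance) and MinHashLSH (Jaccard
distance). Assuming the ranker is built on those primitives (an inference from the names,
not something this diff states), a self-contained sketch of the underlying Spark ML API:

    // Standalone Spark ML LSH sketch; illustrative only, not the annotator's code.
    import org.apache.spark.ml.feature.BucketedRandomProjectionLSH
    import org.apache.spark.ml.linalg.Vectors

    val vectors = spark
      .createDataFrame(
        Seq(
          (0, Vectors.dense(1.0, 1.0)),
          (1, Vectors.dense(1.0, -1.0)),
          (2, Vectors.dense(-1.0, -1.0))))
      .toDF("id", "features")

    val brp = new BucketedRandomProjectionLSH()
      .setBucketLength(2.0)
      .setNumHashTables(3)
      .setInputCol("features")
      .setOutputCol("hashes")

    // Approximate 2 nearest neighbours of a query vector, by Euclidean distance.
    brp.fit(vectors).approxNearestNeighbors(vectors, Vectors.dense(1.0, 0.0), 2).show()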
+
+  "Databricks pipeline" should "use brp to rank document similarity" taggedAs SlowTest in {
+    // Imports repeated locally so the test body can be pasted as-is into a notebook cell.
+    import com.johnsnowlabs.nlp.AnnotatorType.DOC_SIMILARITY_RANKINGS
+    import com.johnsnowlabs.nlp.annotators.Tokenizer
+    import com.johnsnowlabs.nlp.annotators.sbd.pragmatic.SentenceDetector
+    import com.johnsnowlabs.nlp.annotators.similarity.DocumentSimilarityRankerApproach
+    import com.johnsnowlabs.nlp.base.DocumentAssembler
+    import com.johnsnowlabs.nlp.embeddings.{AlbertEmbeddings, SentenceEmbeddings}
+    import com.johnsnowlabs.nlp.finisher.DocumentSimilarityRankerFinisher
+    import com.johnsnowlabs.nlp.util.io.ResourceHelper
+    import com.johnsnowlabs.nlp.EmbeddingsFinisher
+    import org.apache.spark.ml.{Pipeline, PipelineModel}
+
+    val smallCorpus = spark
+      .createDataFrame(
+        List(
+          "First document, this is my first sentence. This is my second sentence.",
+          "Second document, this is my second sentence. This is my second sentence.",
+          "Third document, climate change is arguably one of the most pressing problems of our time.",
+          "Fourth document, climate change is definitely one of the most pressing problems of our time.",
+          "Fifth document, Florence in Italy, is among the most beautiful cities in Europe.",
+          "Sixth document, Florence in Italy, is a very beautiful city in Europe like Lyon in France.",
+          "Seventh document, the French Riviera is the Mediterranean coastline of the southeast corner of France.",
+          "Eighth document, the warmest place in France is the French Riviera coast in Southern France.")
+          .map(Tuple1(_)))
+      .toDF("text")
+
+    val documentAssembler = new DocumentAssembler()
+      .setInputCol("text")
+      .setOutputCol("document")
+
+    val sentence = new SentenceDetector()
+      .setInputCols("document")
+      .setOutputCol("sentence")
+
+    val tokenizer = new Tokenizer()
+      .setInputCols(Array("document"))
+      .setOutputCol("token")
+
+    val embeddings = AlbertEmbeddings
+      .pretrained()
+      .setInputCols("sentence", "token")
+      .setOutputCol("embeddings")
+
+    val embeddingsSentence = new SentenceEmbeddings()
+      .setInputCols(Array("document", "embeddings"))
+      .setOutputCol("sentence_embeddings")
+      .setPoolingStrategy("AVERAGE")
+
+    val sentenceFinisher = new EmbeddingsFinisher()
+      .setInputCols("sentence_embeddings")
+      .setOutputCols("finished_sentence_embeddings")
+      .setCleanAnnotations(false)
+
+    val docSimilarityRanker = new DocumentSimilarityRankerApproach()
+      .setInputCols("sentence_embeddings")
+      .setOutputCol(DOC_SIMILARITY_RANKINGS)
+      .setSimilarityMethod("brp")
+      .setNumberOfNeighbours(3)
+      .setVisibleDistances(true)
+      .setIdentityRanking(true)
+
+    val documentSimilarityFinisher = new DocumentSimilarityRankerFinisher()
+      .setInputCols("doc_similarity_rankings")
+      .setOutputCols(
+        "finished_doc_similarity_rankings_id",
+        "finished_doc_similarity_rankings_neighbors")
+      .setExtractNearestNeighbor(true)
+
+    val pipeline = new Pipeline()
+      .setStages(
+        Array(
+          documentAssembler,
+          sentence,
+          tokenizer,
+          embeddings,
+          embeddingsSentence,
+          sentenceFinisher,
+          docSimilarityRanker,
+          documentSimilarityFinisher))
+
+    val transformed = pipeline.fit(smallCorpus).transform(smallCorpus)
+
+    transformed
+      .select("text", "sentence_embeddings.embeddings")
+      .withColumn("extracted_embeddings", element_at(col("embeddings"), 1))
+      .withColumn("embeddings_size", size(col("extracted_embeddings")))
+      .show(10, false)
+  }
+}
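Given setIdentityRanking(true) in all three tests, a stricter variant of the earlier
assertions would demand one zero-distance self match per input document rather than
merely a non-empty result. A sketch, with the caveat that LSH search is approximate and
could in principle drop a self match:

    // Every document should be its own nearest neighbour at distance 0.0,
    // so the zero-distance row count should equal the corpus size.
    val selfMatches = transformed
      .where(col("nearest_neighbor_distance") === 0.0)
      .count()
    assert(selfMatches == smallCorpus.count())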