diff --git a/example/src/main/scala/com/vesoft/nebula/algorithm/AlgoPerformanceTest.scala b/example/src/main/scala/com/vesoft/nebula/algorithm/AlgoPerformanceTest.scala index 7900b12..93b6ade 100644 --- a/example/src/main/scala/com/vesoft/nebula/algorithm/AlgoPerformanceTest.scala +++ b/example/src/main/scala/com/vesoft/nebula/algorithm/AlgoPerformanceTest.scala @@ -44,7 +44,7 @@ object AlgoPerformanceTest { .builder() .withMetaAddress("127.0.0.0.1:9559") .withTimeout(6000) - .withConenctionRetry(2) + .withConnectionRetry(2) .build() val nebulaReadEdgeConfig: ReadNebulaConfig = ReadNebulaConfig .builder() diff --git a/example/src/main/scala/com/vesoft/nebula/algorithm/DeepQueryTest.scala b/example/src/main/scala/com/vesoft/nebula/algorithm/DeepQueryTest.scala index 3feaea5..505f007 100644 --- a/example/src/main/scala/com/vesoft/nebula/algorithm/DeepQueryTest.scala +++ b/example/src/main/scala/com/vesoft/nebula/algorithm/DeepQueryTest.scala @@ -39,7 +39,7 @@ object DeepQueryTest { .builder() .withMetaAddress("192.168.15.5:9559") .withTimeout(6000) - .withConenctionRetry(2) + .withConnectionRetry(2) .build() val nebulaReadEdgeConfig: ReadNebulaConfig = ReadNebulaConfig .builder() diff --git a/example/src/main/scala/com/vesoft/nebula/algorithm/ReadData.scala b/example/src/main/scala/com/vesoft/nebula/algorithm/ReadData.scala index 8f0e070..7efd81e 100644 --- a/example/src/main/scala/com/vesoft/nebula/algorithm/ReadData.scala +++ b/example/src/main/scala/com/vesoft/nebula/algorithm/ReadData.scala @@ -67,7 +67,7 @@ object ReadData { .builder() .withMetaAddress("127.0.0.1:9559") .withTimeout(6000) - .withConenctionRetry(2) + .withConnectionRetry(2) .build() val nebulaReadEdgeConfig: ReadNebulaConfig = ReadNebulaConfig .builder() diff --git a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/Main.scala b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/Main.scala index 4b41472..13ec599 100644 --- a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/Main.scala +++ b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/Main.scala @@ -6,47 +6,10 @@ package com.vesoft.nebula.algorithm import com.vesoft.nebula.algorithm.config.Configs.Argument -import com.vesoft.nebula.algorithm.config.{ - AlgoConfig, - BetweennessConfig, - BfsConfig, - CcConfig, - CoefficientConfig, - Configs, - DfsConfig, - HanpConfig, - JaccardConfig, - KCoreConfig, - LPAConfig, - LouvainConfig, - Node2vecConfig, - PRConfig, - ShortestPathConfig, - SparkConfig, - DegreeStaticConfig -} -import com.vesoft.nebula.algorithm.lib.{ - BetweennessCentralityAlgo, - BfsAlgo, - ClosenessAlgo, - ClusteringCoefficientAlgo, - ConnectedComponentsAlgo, - DegreeStaticAlgo, - DfsAlgo, - GraphTriangleCountAlgo, - HanpAlgo, - JaccardAlgo, - KCoreAlgo, - LabelPropagationAlgo, - LouvainAlgo, - Node2vecAlgo, - PageRankAlgo, - ShortestPathAlgo, - StronglyConnectedComponentsAlgo, - TriangleCountAlgo -} -import com.vesoft.nebula.algorithm.reader.{CsvReader, JsonReader, NebulaReader} -import com.vesoft.nebula.algorithm.writer.{CsvWriter, NebulaWriter, TextWriter} +import com.vesoft.nebula.algorithm.config._ +import com.vesoft.nebula.algorithm.lib._ +import com.vesoft.nebula.algorithm.reader.{CsvReader, DataReader, JsonReader, NebulaReader} +import com.vesoft.nebula.algorithm.writer.{AlgoWriter, CsvWriter, NebulaWriter, TextWriter} import org.apache.commons.math3.ode.UnknownParameterException import org.apache.log4j.Logger import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} @@ -114,26 +77,8 @@ object Main { private[this] def createDataSource(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame = { - val dataSource = configs.dataSourceSinkEntry.source - val dataSet: Dataset[Row] = dataSource.toLowerCase match { - case "nebula" => { - val reader = new NebulaReader(spark, configs, partitionNum) - reader.read() - } - case "nebula-ngql" => { - val reader = new NebulaReader(spark, configs, partitionNum) - reader.readNgql() - } - case "csv" => { - val reader = new CsvReader(spark, configs, partitionNum) - reader.read() - } - case "json" => { - val reader = new JsonReader(spark, configs, partitionNum) - reader.read() - } - } - dataSet + val dataSource = DataReader.make(configs) + dataSource.read(spark, configs, partitionNum) } /** @@ -149,99 +94,63 @@ object Main { configs: Configs, dataSet: DataFrame): DataFrame = { val hasWeight = configs.dataSourceSinkEntry.hasWeight - val algoResult = { - algoName.toLowerCase match { - case "pagerank" => { - val pageRankConfig = PRConfig.getPRConfig(configs) - PageRankAlgo(spark, dataSet, pageRankConfig, hasWeight) - } - case "louvain" => { - val louvainConfig = LouvainConfig.getLouvainConfig(configs) - LouvainAlgo(spark, dataSet, louvainConfig, hasWeight) - } - case "connectedcomponent" => { - val ccConfig = CcConfig.getCcConfig(configs) - ConnectedComponentsAlgo(spark, dataSet, ccConfig, hasWeight) - } - case "labelpropagation" => { - val lpaConfig = LPAConfig.getLPAConfig(configs) - LabelPropagationAlgo(spark, dataSet, lpaConfig, hasWeight) - } - case "shortestpaths" => { - val spConfig = ShortestPathConfig.getShortestPathConfig(configs) - ShortestPathAlgo(spark, dataSet, spConfig, hasWeight) - } - case "degreestatic" => { - val dsConfig = DegreeStaticConfig.getDegreeStaticConfig(configs) - DegreeStaticAlgo(spark, dataSet, dsConfig) - } - case "kcore" => { - val kCoreConfig = KCoreConfig.getKCoreConfig(configs) - KCoreAlgo(spark, dataSet, kCoreConfig) - } - case "stronglyconnectedcomponent" => { - val ccConfig = CcConfig.getCcConfig(configs) - StronglyConnectedComponentsAlgo(spark, dataSet, ccConfig, hasWeight) - } - case "betweenness" => { - val betweennessConfig = BetweennessConfig.getBetweennessConfig(configs) - BetweennessCentralityAlgo(spark, dataSet, betweennessConfig, hasWeight) - } - case "trianglecount" => { - TriangleCountAlgo(spark, dataSet) - } - case "graphtrianglecount" => { - GraphTriangleCountAlgo(spark, dataSet) - } - case "clusteringcoefficient" => { - val coefficientConfig = CoefficientConfig.getCoefficientConfig(configs) - ClusteringCoefficientAlgo(spark, dataSet, coefficientConfig) - } - case "closeness" => { - ClosenessAlgo(spark, dataSet, hasWeight) - } - case "hanp" => { - val hanpConfig = HanpConfig.getHanpConfig(configs) - HanpAlgo(spark, dataSet, hanpConfig, hasWeight) - } - case "node2vec" => { - val node2vecConfig = Node2vecConfig.getNode2vecConfig(configs) - Node2vecAlgo(spark, dataSet, node2vecConfig, hasWeight) - } - case "bfs" => { - val bfsConfig = BfsConfig.getBfsConfig(configs) - BfsAlgo(spark, dataSet, bfsConfig) - } - case "dfs" => { - val dfsConfig = DfsConfig.getDfsConfig(configs) - DfsAlgo(spark, dataSet, dfsConfig) - } - case "jaccard" => { - val jaccardConfig = JaccardConfig.getJaccardConfig(configs) - JaccardAlgo(spark, dataSet, jaccardConfig) - } - case _ => throw new UnknownParameterException("unknown executeAlgo name.") - } + AlgorithmType.mapping.getOrElse(algoName.toLowerCase, throw new UnknownParameterException("unknown executeAlgo name.")) match { + case AlgorithmType.Bfs => + val bfsConfig = BfsConfig.getBfsConfig(configs) + BfsAlgo(spark, dataSet, bfsConfig) + case AlgorithmType.Closeness => + ClosenessAlgo(spark, dataSet, hasWeight) + case AlgorithmType.ClusteringCoefficient => + val coefficientConfig = CoefficientConfig.getCoefficientConfig(configs) + ClusteringCoefficientAlgo(spark, dataSet, coefficientConfig) + case AlgorithmType.ConnectedComponents => + val ccConfig = CcConfig.getCcConfig(configs) + ConnectedComponentsAlgo(spark, dataSet, ccConfig, hasWeight) + case AlgorithmType.DegreeStatic => + val dsConfig = DegreeStaticConfig.getDegreeStaticConfig(configs) + DegreeStaticAlgo(spark, dataSet, dsConfig) + case AlgorithmType.Dfs => + val dfsConfig = DfsConfig.getDfsConfig(configs) + DfsAlgo(spark, dataSet, dfsConfig) + case AlgorithmType.GraphTriangleCount => + GraphTriangleCountAlgo(spark, dataSet) + case AlgorithmType.Hanp => + val hanpConfig = HanpConfig.getHanpConfig(configs) + HanpAlgo(spark, dataSet, hanpConfig, hasWeight) + case AlgorithmType.Jaccard => + val jaccardConfig = JaccardConfig.getJaccardConfig(configs) + JaccardAlgo(spark, dataSet, jaccardConfig) + case AlgorithmType.KCore => + val kCoreConfig = KCoreConfig.getKCoreConfig(configs) + KCoreAlgo(spark, dataSet, kCoreConfig) + case AlgorithmType.LabelPropagation => + val lpaConfig = LPAConfig.getLPAConfig(configs) + LabelPropagationAlgo(spark, dataSet, lpaConfig, hasWeight) + case AlgorithmType.Louvain => + val louvainConfig = LouvainConfig.getLouvainConfig(configs) + LouvainAlgo(spark, dataSet, louvainConfig, hasWeight) + case AlgorithmType.Node2vec => + val node2vecConfig = Node2vecConfig.getNode2vecConfig(configs) + Node2vecAlgo(spark, dataSet, node2vecConfig, hasWeight) + case AlgorithmType.PageRank => + val pageRankConfig = PRConfig.getPRConfig(configs) + PageRankAlgo(spark, dataSet, pageRankConfig, hasWeight) + case AlgorithmType.ShortestPath => + val spConfig = ShortestPathConfig.getShortestPathConfig(configs) + ShortestPathAlgo(spark, dataSet, spConfig, hasWeight) + case AlgorithmType.StronglyConnectedComponents => + val ccConfig = CcConfig.getCcConfig(configs) + StronglyConnectedComponentsAlgo(spark, dataSet, ccConfig, hasWeight) + case AlgorithmType.TriangleCount => + TriangleCountAlgo(spark, dataSet) + case AlgorithmType.BetweennessCentrality => + val betweennessConfig = BetweennessConfig.getBetweennessConfig(configs) + BetweennessCentralityAlgo(spark, dataSet, betweennessConfig, hasWeight) } - algoResult } private[this] def saveAlgoResult(algoResult: DataFrame, configs: Configs): Unit = { - val dataSink = configs.dataSourceSinkEntry.sink - dataSink.toLowerCase match { - case "nebula" => { - val writer = new NebulaWriter(algoResult, configs) - writer.write() - } - case "csv" => { - val writer = new CsvWriter(algoResult, configs) - writer.write() - } - case "text" => { - val writer = new TextWriter(algoResult, configs) - writer.write() - } - case _ => throw new UnsupportedOperationException("unsupported data sink") - } + val writer = AlgoWriter.make(configs) + writer.write(algoResult, configs) } } diff --git a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/AlgorithmType.scala b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/AlgorithmType.scala new file mode 100644 index 0000000..7ff4d13 --- /dev/null +++ b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/AlgorithmType.scala @@ -0,0 +1,74 @@ +/* Copyright (c) 2021 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. + */ + +package com.vesoft.nebula.algorithm.lib + +/** + * + * @author 梦境迷离 + * @version 1.0,2023/9/12 + */ +sealed trait AlgorithmType { + self => + def stringify: String = self match { + case AlgorithmType.Bfs => "bfs" + case AlgorithmType.Closeness => "closeness" + case AlgorithmType.ClusteringCoefficient => "clusteringcoefficient" + case AlgorithmType.ConnectedComponents => "connectedcomponent" + case AlgorithmType.DegreeStatic => "degreestatic" + case AlgorithmType.Dfs => "dfs" + case AlgorithmType.GraphTriangleCount => "graphtrianglecount" + case AlgorithmType.Hanp => "hanp" + case AlgorithmType.Jaccard => "jaccard" + case AlgorithmType.KCore => "kcore" + case AlgorithmType.LabelPropagation => "labelpropagation" + case AlgorithmType.Louvain => "louvain" + case AlgorithmType.Node2vec => "node2vec" + case AlgorithmType.PageRank => "pagerank" + case AlgorithmType.ShortestPath => "shortestpaths" + case AlgorithmType.StronglyConnectedComponents => "stronglyconnectedcomponent" + case AlgorithmType.TriangleCount => "trianglecount" + case AlgorithmType.BetweennessCentrality => "betweenness" + } +} +object AlgorithmType { + lazy val mapping: Map[String, AlgorithmType] = Map( + Bfs.stringify -> Bfs, + Closeness.stringify -> Closeness, + ClusteringCoefficient.stringify -> ClusteringCoefficient, + ConnectedComponents.stringify -> ConnectedComponents, + DegreeStatic.stringify -> DegreeStatic, + GraphTriangleCount.stringify -> GraphTriangleCount, + Hanp.stringify -> Hanp, + Jaccard.stringify -> Jaccard, + KCore.stringify -> KCore, + LabelPropagation.stringify -> LabelPropagation, + Louvain.stringify -> Louvain, + Node2vec.stringify -> Node2vec, + PageRank.stringify -> PageRank, + ShortestPath.stringify -> ShortestPath, + StronglyConnectedComponents.stringify -> StronglyConnectedComponents, + TriangleCount.stringify -> TriangleCount, + BetweennessCentrality.stringify -> BetweennessCentrality + ) + object BetweennessCentrality extends AlgorithmType + object Bfs extends AlgorithmType + object Closeness extends AlgorithmType + object ClusteringCoefficient extends AlgorithmType + object ConnectedComponents extends AlgorithmType + object DegreeStatic extends AlgorithmType + object Dfs extends AlgorithmType + object GraphTriangleCount extends AlgorithmType + object Hanp extends AlgorithmType + object Jaccard extends AlgorithmType + object KCore extends AlgorithmType + object LabelPropagation extends AlgorithmType + object Louvain extends AlgorithmType + object Node2vec extends AlgorithmType + object PageRank extends AlgorithmType + object ShortestPath extends AlgorithmType + object StronglyConnectedComponents extends AlgorithmType + object TriangleCount extends AlgorithmType +} diff --git a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/reader/DataReader.scala b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/reader/DataReader.scala index 431db22..b1cbe28 100644 --- a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/reader/DataReader.scala +++ b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/reader/DataReader.scala @@ -12,13 +12,27 @@ import org.apache.spark.sql.{DataFrame, SparkSession} import scala.collection.mutable.ListBuffer -abstract class DataReader(spark: SparkSession, configs: Configs) { - def read(): DataFrame +abstract class DataReader { + val tpe: ReaderType + def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame +} +object DataReader { + def make(configs: Configs): DataReader = { + ReaderType.mapping + .get(configs.dataSourceSinkEntry.source.toLowerCase) + .collect { + case ReaderType.json => new JsonReader + case ReaderType.nebulaNgql => new NebulaNgqlReader + case ReaderType.nebula => new NebulaReader + case ReaderType.csv => new CsvReader + } + .getOrElse(throw new UnsupportedOperationException("unsupported reader")) + } } -class NebulaReader(spark: SparkSession, configs: Configs, partitionNum: String) - extends DataReader(spark, configs) { - override def read(): DataFrame = { +class NebulaReader extends DataReader { + override val tpe: ReaderType = ReaderType.nebula + override def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame = { val metaAddress = configs.nebulaConfig.readConfigEntry.address val space = configs.nebulaConfig.readConfigEntry.space val labels = configs.nebulaConfig.readConfigEntry.labels @@ -29,7 +43,7 @@ class NebulaReader(spark: SparkSession, configs: Configs, partitionNum: String) NebulaConnectionConfig .builder() .withMetaAddress(metaAddress) - .withConenctionRetry(2) + .withConnectionRetry(2) .build() val noColumn = weights.isEmpty @@ -66,7 +80,12 @@ class NebulaReader(spark: SparkSession, configs: Configs, partitionNum: String) dataset } - def readNgql(): DataFrame = { +} +final class NebulaNgqlReader extends NebulaReader { + + override val tpe: ReaderType = ReaderType.nebulaNgql + + override def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame = { val metaAddress = configs.nebulaConfig.readConfigEntry.address val graphAddress = configs.nebulaConfig.readConfigEntry.graphAddress val space = configs.nebulaConfig.readConfigEntry.space @@ -80,7 +99,7 @@ class NebulaReader(spark: SparkSession, configs: Configs, partitionNum: String) .builder() .withMetaAddress(metaAddress) .withGraphAddress(graphAddress) - .withConenctionRetry(2) + .withConnectionRetry(2) .build() var dataset: DataFrame = null @@ -113,11 +132,12 @@ class NebulaReader(spark: SparkSession, configs: Configs, partitionNum: String) } dataset } + } -class CsvReader(spark: SparkSession, configs: Configs, partitionNum: String) - extends DataReader(spark, configs) { - override def read(): DataFrame = { +final class CsvReader extends DataReader { + override val tpe: ReaderType = ReaderType.csv + override def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame = { val delimiter = configs.localConfigEntry.delimiter val header = configs.localConfigEntry.header val localPath = configs.localConfigEntry.filePath @@ -132,7 +152,7 @@ class CsvReader(spark: SparkSession, configs: Configs, partitionNum: String) val weight = configs.localConfigEntry.weight val src = configs.localConfigEntry.srcId val dst = configs.localConfigEntry.dstId - if (configs.dataSourceSinkEntry.hasWeight && weight != null && !weight.trim.isEmpty) { + if (configs.dataSourceSinkEntry.hasWeight && weight != null && weight.trim.nonEmpty) { data.select(src, dst, weight) } else { data.select(src, dst) @@ -143,10 +163,9 @@ class CsvReader(spark: SparkSession, configs: Configs, partitionNum: String) data } } - -class JsonReader(spark: SparkSession, configs: Configs, partitionNum: String) - extends DataReader(spark, configs) { - override def read(): DataFrame = { +final class JsonReader extends DataReader { + override val tpe: ReaderType = ReaderType.json + override def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame = { val localPath = configs.localConfigEntry.filePath val data = spark.read.json(localPath) val partition = partitionNum.toInt @@ -154,7 +173,7 @@ class JsonReader(spark: SparkSession, configs: Configs, partitionNum: String) val weight = configs.localConfigEntry.weight val src = configs.localConfigEntry.srcId val dst = configs.localConfigEntry.dstId - if (configs.dataSourceSinkEntry.hasWeight && weight != null && !weight.trim.isEmpty) { + if (configs.dataSourceSinkEntry.hasWeight && weight != null && weight.trim.nonEmpty) { data.select(src, dst, weight) } else { data.select(src, dst) diff --git a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/reader/ReaderType.scala b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/reader/ReaderType.scala new file mode 100644 index 0000000..ca1d101 --- /dev/null +++ b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/reader/ReaderType.scala @@ -0,0 +1,33 @@ +/* Copyright (c) 2021 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. + */ + +package com.vesoft.nebula.algorithm.reader + +/** + * + * @author 梦境迷离 + * @version 1.0,2023/9/12 + */ +sealed trait ReaderType { + self => + def stringify: String = self match { + case ReaderType.json => "json" + case ReaderType.nebulaNgql => "nebula-ngql" + case ReaderType.nebula => "nebula" + case ReaderType.csv => "csv" + } +} +object ReaderType { + lazy val mapping: Map[String, ReaderType] = Map( + json.stringify -> json, + nebulaNgql.stringify -> nebulaNgql, + nebula.stringify -> nebula, + csv.stringify -> csv + ) + object json extends ReaderType + object nebulaNgql extends ReaderType + object nebula extends ReaderType + object csv extends ReaderType +} diff --git a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/writer/AlgoWriter.scala b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/writer/AlgoWriter.scala index b56831d..e4da34d 100644 --- a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/writer/AlgoWriter.scala +++ b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/writer/AlgoWriter.scala @@ -10,12 +10,24 @@ import com.vesoft.nebula.connector.{NebulaConnectionConfig, WriteMode, WriteNebu import com.vesoft.nebula.algorithm.config.{AlgoConstants, Configs} import org.apache.spark.sql.DataFrame -abstract class AlgoWriter(data: DataFrame, configs: Configs) { - def write(): Unit +abstract class AlgoWriter { + val tpe:WriterType + def write(data: DataFrame, configs: Configs): Unit +} +object AlgoWriter { + def make(configs: Configs): AlgoWriter = { + WriterType.mapping.get(configs.dataSourceSinkEntry.sink.toLowerCase).collect { + case WriterType.text => new TextWriter + case WriterType.nebula => new NebulaWriter + case WriterType.csv => new CsvWriter + }.getOrElse(throw new UnsupportedOperationException("unsupported writer")) + + } } -class NebulaWriter(data: DataFrame, configs: Configs) extends AlgoWriter(data, configs) { - override def write(): Unit = { +final class NebulaWriter extends AlgoWriter { + override val tpe: WriterType = WriterType.nebula + override def write(data: DataFrame, configs: Configs): Unit = { val graphAddress = configs.nebulaConfig.writeConfigEntry.graphAddress val metaAddress = configs.nebulaConfig.writeConfigEntry.metaAddress val space = configs.nebulaConfig.writeConfigEntry.space @@ -30,7 +42,7 @@ class NebulaWriter(data: DataFrame, configs: Configs) extends AlgoWriter(data, c .builder() .withMetaAddress(metaAddress) .withGraphAddress(graphAddress) - .withConenctionRetry(2) + .withConnectionRetry(2) .build() val nebulaWriteVertexConfig = WriteNebulaVertexConfig .builder() @@ -47,15 +59,17 @@ class NebulaWriter(data: DataFrame, configs: Configs) extends AlgoWriter(data, c } } -class CsvWriter(data: DataFrame, configs: Configs) extends AlgoWriter(data, configs) { - override def write(): Unit = { +final class CsvWriter extends AlgoWriter { + override val tpe: WriterType = WriterType.csv + override def write(data: DataFrame, configs: Configs): Unit = { val resultPath = configs.localConfigEntry.resultPath data.write.option("header", true).csv(resultPath) } } -class TextWriter(data: DataFrame, configs: Configs) extends AlgoWriter(data, configs) { - override def write(): Unit = { +final class TextWriter extends AlgoWriter { + override val tpe: WriterType = WriterType.text + override def write(data: DataFrame, configs: Configs): Unit = { val resultPath = configs.localConfigEntry.resultPath data.write.option("header", true).text(resultPath) } diff --git a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/writer/WriterType.scala b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/writer/WriterType.scala new file mode 100644 index 0000000..84a7839 --- /dev/null +++ b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/writer/WriterType.scala @@ -0,0 +1,30 @@ +/* Copyright (c) 2021 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. + */ + +package com.vesoft.nebula.algorithm.writer + +/** + * + * @author 梦境迷离 + * @version 1.0,2023/9/12 + */ +sealed trait WriterType { + self => + def stringify: String = self match { + case WriterType.text => "text" + case WriterType.nebula => "nebula" + case WriterType.csv => "csv" + } +} +object WriterType { + lazy val mapping: Map[String, WriterType] = Map( + text.stringify -> text, + nebula.stringify -> nebula, + csv.stringify -> csv + ) + object text extends WriterType + object nebula extends WriterType + object csv extends WriterType +}