diff --git a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/utils/NebulaUtil.scala b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/utils/NebulaUtil.scala index aab0578..31a40f4 100644 --- a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/utils/NebulaUtil.scala +++ b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/utils/NebulaUtil.scala @@ -21,15 +21,10 @@ object NebulaUtil { */ def loadInitGraph(dataSet: Dataset[Row], hasWeight: Boolean): Graph[None.type, Double] = { implicit val encoder: Encoder[Edge[Double]] = org.apache.spark.sql.Encoders.kryo[Edge[Double]] - val edges: RDD[Edge[Double]] = dataSet - .map(row => { - if (hasWeight) { - Edge(row.get(0).toString.toLong, row.get(1).toString.toLong, row.get(2).toString.toDouble) - } else { - Edge(row.get(0).toString.toLong, row.get(1).toString.toLong, 1.0) - } - })(encoder) - .rdd + val edges: RDD[Edge[Double]] = dataSet.map { row => + val attr = if (hasWeight) row.get(2).toString.toDouble else 1.0 + Edge(row.get(0).toString.toLong, row.get(1).toString.toLong, attr) + }(encoder).rdd Graph.fromEdges(edges, None) } @@ -42,11 +37,7 @@ object NebulaUtil { * * @return validate result path */ - def getResultPath(path: String, algorithmName: String): String = { - var resultFilePath = path - if (!resultFilePath.endsWith("/")) { - resultFilePath = resultFilePath + "/" - } - resultFilePath + algorithmName - } + def getResultPath(path: String, algorithmName: String): String = + if (path.endsWith("/")) s"$path$algorithmName" + else s"$path/$algorithmName" }